,
+ 'label': 6,
+ 'pixel_values': tensor([[[ 0.0353, 0.0745, 0.1216, ..., -0.9922, -0.9922, -0.9922],
+ [-0.0196, 0.0667, 0.1294, ..., -0.9765, -0.9843, -0.9922],
+ [ 0.0196, 0.0824, 0.1137, ..., -0.9765, -0.9686, -0.8667],
+ ...,
+ [ 0.0275, 0.0745, 0.0510, ..., -0.1137, -0.1216, -0.0824],
+ [ 0.0667, 0.0824, 0.0667, ..., -0.0588, -0.0745, -0.0980],
+ [ 0.0353, 0.0353, 0.0431, ..., -0.0039, -0.0039, -0.0588]],
+
+ [[ 0.2078, 0.2471, 0.2863, ..., -0.9451, -0.9373, -0.9451],
+ [ 0.1608, 0.2471, 0.3098, ..., -0.9373, -0.9451, -0.9373],
+ [ 0.2078, 0.2706, 0.3020, ..., -0.9608, -0.9373, -0.8275],
+ ...,
+ [-0.0353, 0.0118, -0.0039, ..., -0.2392, -0.2471, -0.2078],
+ [ 0.0196, 0.0353, 0.0196, ..., -0.1843, -0.2000, -0.2235],
+ [-0.0118, -0.0039, -0.0039, ..., -0.0980, -0.0980, -0.1529]],
+
+ [[ 0.3961, 0.4431, 0.4980, ..., -0.9216, -0.9137, -0.9216],
+ [ 0.3569, 0.4510, 0.5216, ..., -0.9059, -0.9137, -0.9137],
+ [ 0.4118, 0.4745, 0.5216, ..., -0.9137, -0.8902, -0.7804],
+ ...,
+ [-0.2314, -0.1922, -0.2078, ..., -0.4196, -0.4275, -0.3882],
+ [-0.1843, -0.1686, -0.2000, ..., -0.3647, -0.3804, -0.4039],
+ [-0.1922, -0.1922, -0.1922, ..., -0.2941, -0.2863, -0.3412]]])}
+```
+
+Este es el aspecto de la imagen despuƩs de preprocesarla. Como era de esperar por las transformaciones aplicadas, la imagen ha sido recortada aleatoriamente y sus propiedades de color son diferentes.
+
+```py
+>>> import numpy as np
+>>> import matplotlib.pyplot as plt
+
+>>> img = dataset[0]["pixel_values"]
+>>> plt.imshow(img.permute(1, 2, 0))
+```
+
+![preprocessed_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/preprocessed_image.png)
+
+## Multimodal
+
+Para las tareas multimodales utilizarĆ”s una combinaciĆ³n de todo lo que has aprendido hasta ahora y aplicarĆ”s tus habilidades a una tarea de reconocimiento automĆ”tico de voz (ASR). Esto significa que necesitarĆ”s un:
+
+* Extractor de caracterĆsticas para preprocesar los datos de audio.
+* Un tokenizador para procesar el texto.
+
+Volvamos al dataset [LJ Speech](https://huggingface.co/datasets/lj_speech):
+
+```py
+>>> from datasets import load_dataset
+
+>>> lj_speech = load_dataset("lj_speech", split="train")
+```
+
+Suponiendo que te interesan principalmente las columnas `audio` y `texto`, elimina las demƔs columnas:
+
+```py
+>>> lj_speech = lj_speech.map(remove_columns=["file", "id", "normalized_text"])
+```
+
+Ahora echa un vistazo a las columnas `audio` y `texto`:
+
+```py
+>>> lj_speech[0]["audio"]
+{'array': array([-7.3242188e-04, -7.6293945e-04, -6.4086914e-04, ...,
+ 7.3242188e-04, 2.1362305e-04, 6.1035156e-05], dtype=float32),
+ 'path': '/root/.cache/huggingface/datasets/downloads/extracted/917ece08c95cf0c4115e45294e3cd0dee724a1165b7fc11798369308a465bd26/LJSpeech-1.1/wavs/LJ001-0001.wav',
+ 'sampling_rate': 22050}
+
+>>> lj_speech[0]["text"]
+'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition'
+```
+
+Recuerda la secciĆ³n anterior sobre el procesamiento de datos de audio, siempre debes [volver a muestrear](preprocessing#audio) la tasa de muestreo de tus datos de audio para que coincida con la tasa de muestreo del dataset utilizado para preentrenar un modelo:
+
+```py
+>>> lj_speech = lj_speech.cast_column("audio", Audio(sampling_rate=16_000))
+```
+
+### Processor
+
+Un processor combina un extractor de caracterĆsticas y un tokenizador. Cargue un procesador con [`AutoProcessor.from_pretrained]:
+
+```py
+>>> from transformers import AutoProcessor
+
+>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
+```
+
+1. Crea una funciĆ³n para procesar los datos de audio en `input_values`, y tokeniza el texto en `labels`. Estas son las entradas del modelo:
+
+```py
+>>> def prepare_dataset(example):
+... audio = example["audio"]
+
+... example["input_values"] = processor(audio["array"], sampling_rate=16000)
+
+... with processor.as_target_processor():
+... example["labels"] = processor(example["text"]).input_ids
+... return example
+```
+
+2. Aplica la funciĆ³n `prepare_dataset` a una muestra:
+
+```py
+>>> prepare_dataset(lj_speech[0])
+```
+
+Observa que el mƩtodo processor ha aƱadido `input_values` y `labels`. La tasa de muestreo tambiƩn se ha reducido correctamente a 16kHz.
+
+Genial, ahora deberĆas ser capaz de preprocesar datos para cualquier modalidad e incluso combinar diferentes modalidades. En el siguiente tutorial, aprenderĆ”s aplicar fine tuning a un modelo en tus datos reciĆ©n preprocesados.
+
+## Todo lo que siempre quisiste saber sobre el padding y el truncamiento
+
+Hemos visto los comandos que funcionarĆ”n para la mayorĆa de los casos (hacer pad a tu batch teniendo en cuenta la longitud de la frase mĆ”xima y
+truncar a la longitud mƔxima que el modelo puede aceptar). Sin embargo, la API admite mƔs estrategias si las necesitas. Los
+tres argumentos que necesitas conocer para ello son `padding`, `truncation` y `max_length`.
+
+- `padding` controla el aplicarme padding al texto. Puede ser un booleano o una cadena que debe ser:
+
+ - `True` o `'longest'` para aplicar el pad hasta la secuencia mĆ”s larga del batch (no apliques el padding si sĆ³lo se proporcionas
+ una sola secuencia).
+ - `'max_length'` para aplicar el pad hasta la longitud especificada por el argumento `max_length` o la longitud mƔxima aceptada
+ por el modelo si no le proporcionas `longitud_mĆ”xima` (`longitud_mĆ”xima=None`). Si sĆ³lo le proporcionas una Ćŗnica secuencia
+ se le aplicarĆ” el padding.
+ `False` o `'do_not_pad'` para no aplicar pad a las secuencias. Como hemos visto antes, este es el comportamiento por
+ defecto.
+
+- `truncation` controla el truncamiento. Puede ser un booleano o una string que debe ser:
+
+ - `True` o `'longest_first'` truncan hasta la longitud mƔxima especificada por el argumento `max_length` o
+ la longitud mƔxima aceptada por el modelo si no le proporcionas `max_length` (`max_length=None`). Esto
+ truncarƔ token por token, eliminando un token de la secuencia mƔs larga del par hasta alcanzar la longitud
+ adecuada.
+ - `'only_second'` trunca hasta la longitud mƔxima especificada por el argumento `max_length` o la
+ longitud mĆ”xima aceptada por el modelo si no le proporcionas `max_length` (`max_length=None`). Esto sĆ³lo truncarĆ”
+ la segunda frase de un par si le proporcionas un par de secuencias (o un batch de pares de secuencias).
+ - `'only_first'` trunca hasta la longitud mƔxima especificada por el argumento `max_length` o la longitud mƔxima
+ aceptada por el modelo si no se proporciona `max_length` (`max_length=None`). Esto sĆ³lo truncarĆ”
+ la primera frase de un par si se proporciona un par de secuencias (o un lote de pares de secuencias).
+ - `False` o `'do_not_truncate'` para no truncar las secuencias. Como hemos visto antes, este es el comportamiento
+ por defecto.
+
+- `max_length` para controlar la longitud del padding/truncamiento. Puede ser un nĆŗmero entero o `None`, en cuyo caso
+serĆ” por defecto la longitud mĆ”xima que el modelo puede aceptar. Si el modelo no tiene una longitud mĆ”xima de entrada especĆfica, el
+padding/truncamiento a `longitud_mƔxima` se desactiva.
+
+A continuaciĆ³n te mostramos en una tabla que resume la forma recomendada de configurar el padding y el truncamiento. Si utilizas un par de secuencias de entrada en
+algunos de los siguientes ejemplos, puedes sustituir `truncation=True` por una `STRATEGY` seleccionada en
+`['only_first', 'only_second', 'longest_first']`, es decir, `truncation='only_second'` o `truncation= 'longest_first'` para controlar cĆ³mo se trunquen ambas secuencias del par como lo has detallado anteriormente.
+
+| Truncation | Padding | Instrucciones |
+|--------------------------------------|-----------------------------------|---------------------------------------------------------------------------------------------|
+| no truncation | no padding | `tokenizer(batch_sentences)` |
+| | padding secuencia max del batch | `tokenizer(batch_sentences, padding=True)` or |
+| | | `tokenizer(batch_sentences, padding='longest')` |
+| | padding long max de input model | `tokenizer(batch_sentences, padding='max_length')` |
+| | padding a una long especifica | `tokenizer(batch_sentences, padding='max_length', max_length=42)` |
+| truncation long max del input model | no padding | `tokenizer(batch_sentences, truncation=True)` or |
+| | | `tokenizer(batch_sentences, truncation=STRATEGY)` |
+| | padding secuencia max del batch | `tokenizer(batch_sentences, padding=True, truncation=True)` or |
+| | | `tokenizer(batch_sentences, padding=True, truncation=STRATEGY)` |
+| | padding long max de input model | `tokenizer(batch_sentences, padding='max_length', truncation=True)` or |
+| | | `tokenizer(batch_sentences, padding='max_length', truncation=STRATEGY)` |
+| | padding a una long especifica | Not possible |
+| truncationa una long especifica | no padding | `tokenizer(batch_sentences, truncation=True, max_length=42)` or |
+| | | `tokenizer(batch_sentences, truncation=STRATEGY, max_length=42)` |
+| | padding secuencia max del batch | `tokenizer(batch_sentences, padding=True, truncation=True, max_length=42)` or |
+| | | `tokenizer(batch_sentences, padding=True, truncation=STRATEGY, max_length=42)` |
+| | padding long max de input model | Not possible |
+| | padding a una long especifica | `tokenizer(batch_sentences, padding='max_length', truncation=True, max_length=42)` or |
+| | | `tokenizer(batch_sentences, padding='max_length', truncation=STRATEGY, max_length=42)` |
+
+
+
+
+
+
+
+
diff --git a/docs/source/es/quicktour.mdx b/docs/source/es/quicktour.mdx
index e9fb764a90e610..9de9e9af4b6e55 100644
--- a/docs/source/es/quicktour.mdx
+++ b/docs/source/es/quicktour.mdx
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Quick tour
+# Tour rƔpido
[[open-in-colab]]
@@ -18,7 +18,7 @@ specific language governing permissions and limitations under the License.
-Todos los ejemplos de cĆ³digo presentados en la documentaciĆ³n tienen un botĆ³n arriba a la izquierda para elegir entre Pytorch y TensorFlow.
+Todos los ejemplos de cĆ³digo presentados en la documentaciĆ³n tienen un botĆ³n arriba a la derecha para elegir si quieres ocultar o mostrar el cĆ³digo en Pytorch o TensorFlow.
Si no fuese asĆ, se espera que el cĆ³digo funcione para ambos backends sin ningĆŗn cambio.
@@ -32,23 +32,23 @@ Si no fuese asĆ, se espera que el cĆ³digo funcione para ambos backends sin ning
El [`pipeline`] soporta muchas tareas comunes listas para usar:
**Texto**:
-* AnƔlisis de Sentimientos: clasifica la polaridad de un texto dado.
-* GeneraciĆ³n de texto (solo en inglĆ©s): genera texto a partir de un input dado.
-* Name entity recognition (NER): etiqueta cada palabra con la entidad que representa (persona, fecha, ubicaciĆ³n, etc.).
-* Responder preguntas: extrae la respuesta del contexto dado un contexto y una pregunta.
-* Fill-mask: rellena el espacio faltante dado un texto con palabras enmascaradas.
-* Summarization: genera un resumen de una secuencia larga de texto o un documento.
-* TraducciĆ³n: traduce un texto a otro idioma.
-* ExtracciĆ³n de caracterĆsticas: crea una representaciĆ³n tensorial del texto.
+* AnƔlisis de Sentimiento (Sentiment Analysis, en inglƩs): clasifica la polaridad de un texto dado.
+* GeneraciĆ³n de Texto (Text Generation, en inglĆ©s): genera texto a partir de un input dado.
+* Reconocimiento de Entidades (Name Entity Recognition o NER, en inglĆ©s): etiqueta cada palabra con la entidad que representa (persona, fecha, ubicaciĆ³n, etc.).
+* Responder Preguntas (Question answering, en inglƩs): extrae la respuesta del contexto dado un contexto y una pregunta.
+* Rellenar MƔscara (Fill-mask, en inglƩs): rellena el espacio faltante dado un texto con palabras enmascaradas.
+* Resumir (Summarization, en inglƩs): genera un resumen de una secuencia larga de texto o un documento.
+* TraducciĆ³n (Translation, en inglĆ©s): traduce un texto a otro idioma.
+* ExtracciĆ³n de CaracterĆsticas (Feature Extraction, en inglĆ©s): crea una representaciĆ³n tensorial del texto.
**Imagen**:
-* ClasificaciĆ³n de imĆ”genes: clasifica una imagen.
-* SegmentaciĆ³n de imĆ”genes: clasifica cada pixel de una imagen.
-* DetecciĆ³n de objetos: detecta objetos dentro de una imagen.
+* ClasificaciĆ³n de ImĆ”genes (Image Classification, en inglĆ©s): clasifica una imagen.
+* SegmentaciĆ³n de ImĆ”genes (Image Segmentation, en inglĆ©s): clasifica cada pixel de una imagen.
+* DetecciĆ³n de Objetos (Object Detection, en inglĆ©s): detecta objetos dentro de una imagen.
**Audio**:
-* ClasificaciĆ³n de audios: asigna una etiqueta a un segmento de audio.
-* Automatic speech recognition (ASR): transcribe datos de audio a un texto.
+* ClasificaciĆ³n de Audios (Audio Classification, en inglĆ©s): asigna una etiqueta a un segmento de audio.
+* Reconocimiento de Voz AutomƔtico (Automatic Speech Recognition o ASR, en inglƩs): transcribe datos de audio a un texto.
@@ -80,25 +80,17 @@ Importa [`pipeline`] y especifica la tarea que deseas completar:
```py
>>> from transformers import pipeline
->>> classifier = pipeline("sentiment-analysis")
+>>> clasificador = pipeline("sentiment-analysis", model="pysentimiento/robertuito-sentiment-analysis")
```
-El pipeline descarga y almacena en cachƩ un [modelo preentrenado](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english) por defecto y tokeniza para anƔlisis de sentimiento. Ahora puedes usar `classifier` en tu texto objetivo:
+El pipeline descarga y almacena en cachĆ© el [modelo preentrenado](https://huggingface.co/pysentimiento/robertuito-sentiment-analysis) y tokeniza para anĆ”lisis de sentimiento. Si no hubieramos elegido un modelo el pipeline habrĆa elegido uno por defecto. Ahora puedes usar `clasificador` en tu texto objetivo:
```py
->>> classifier("We are very happy to show you the š¤ Transformers library.")
-[{'label': 'POSITIVE', 'score': 0.9998}]
+>>> clasificador("Estamos muy felices de mostrarte la biblioteca de š¤ Transformers.")
+[{'label': 'POS', 'score': 0.9916}]
```
-Para mƔs de un enunciado entrega una lista de frases al [`pipeline`] que devolverƔ una lista de diccionarios:
-
-```py
->>> results = classifier(["We are very happy to show you the š¤ Transformers library.", "We hope you don't hate it."])
->>> for result in results:
-... print(f"label: {result['label']}, with score: {round(result['score'], 4)}")
-label: POSITIVE, with score: 0.9998
-label: NEGATIVE, with score: 0.5309
-```
+Para mƔs de un enunciado, entrega una lista al [`pipeline`] que devolverƔ una lista de diccionarios:
El [`pipeline`] tambiĆ©n puede iterar sobre un dataset entero. Comienza instalando la biblioteca [š¤ Datasets](https://huggingface.co/docs/datasets/):
@@ -112,7 +104,9 @@ Crea un [`pipeline`] con la tarea que deseas resolver y el modelo que quieres us
>>> import torch
>>> from transformers import pipeline
->>> speech_recognizer = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h", device=0)
+>>> reconocedor_de_voz = pipeline(
+... "automatic-speech-recognition", model="jonatasgrosman/wav2vec2-large-xlsr-53-spanish", device=0
+... )
```
A continuaciĆ³n, carga el dataset (ve š¤ Datasets [Quick Start](https://huggingface.co/docs/datasets/quickstart.html) para mĆ”s detalles) sobre el que quisieras iterar. Por ejemplo, vamos a cargar el dataset [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14):
@@ -120,29 +114,29 @@ A continuaciĆ³n, carga el dataset (ve š¤ Datasets [Quick Start](https://huggin
```py
>>> from datasets import load_dataset, Audio
->>> dataset = load_dataset("PolyAI/minds14", name="en-US", split="train") # doctest: +IGNORE_RESULT
+>>> dataset = load_dataset("PolyAI/minds14", name="es-ES", split="train") # doctest: +IGNORE_RESULT
```
-Debemos asegurarnos de que la frecuencia de muestreo del conjunto de datos coincide con la frecuencia de muestreo con la que se entrenĆ³ `facebook/wav2vec2-base-960h`.
+Debemos asegurarnos de que la frecuencia de muestreo del conjunto de datos coincide con la frecuencia de muestreo con la que se entrenĆ³ `jonatasgrosman/wav2vec2-large-xlsr-53-spanish`.
```py
->>> dataset = dataset.cast_column("audio", Audio(sampling_rate=speech_recognizer.feature_extractor.sampling_rate))
+>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=reconocedor_de_voz.feature_extractor.sampling_rate))
```
-Los archivos de audio se cargan y remuestrean automƔticamente cuando se llama a la columna `"audio"`.
-Extraigamos las matrices de forma de onda cruda de las primeras 4 muestras y pasƩmosla como una lista al pipeline:
+Los archivos de audio se cargan y remuestrean automƔticamente cuando llamamos a la columna `"audio"`.
+Extraigamos las matrices de onda cruda (raw waveform, en inglƩs) de las primeras 4 muestras y pasƩmosla como una lista al pipeline:
```py
->>> result = speech_recognizer(dataset[:4]["audio"])
->>> print([d["text"] for d in result])
-['I WOULD LIKE TO SET UP A JOINT ACCOUNT WITH MY PARTNER HOW DO I PROCEED WITH DOING THAT', "FONDERING HOW I'D SET UP A JOIN TO HET WITH MY WIFE AND WHERE THE AP MIGHT BE", "I I'D LIKE TOY SET UP A JOINT ACCOUNT WITH MY PARTNER I'M NOT SEEING THE OPTION TO DO IT ON THE APSO I CALLED IN TO GET SOME HELP CAN I JUST DO IT OVER THE PHONE WITH YOU AND GIVE YOU THE INFORMATION OR SHOULD I DO IT IN THE AP AND I'M MISSING SOMETHING UQUETTE HAD PREFERRED TO JUST DO IT OVER THE PHONE OF POSSIBLE THINGS", 'HOW DO I TURN A JOIN A COUNT']
+>>> resultado = reconocedor_de_voz(dataset[:4]["audio"])
+>>> print([d["text"] for d in resultado])
+['ahora buenas e a ver tengo un problema como vuestra aplicaciĆ³n resulta que que quiero hacer una transferencia bancaria a una cuenta conocida pero me da error la aplicaciĆ³n a ver que a ver que puede ser', 'la aplicaciĆ³n no cargue salda de mi nueva cuenta', 'hola tengo un problema con la aplicaciĆ³n no carga y y tampoco veo que carga el saldo de mi cuenta nueva dice que la aplicaciĆ³n estĆ” siendo reparada y ahora no puedo aceder a mi cuenta no necesito inmediatamente', 'ora buena la aplicaciĆ³n no se carga la viladad no carga el saldo de mi cuenta nueva dice que la villadenta siendo reparada y oro no puede hacer a mi cuenta']
```
Para un dataset mĆ”s grande, donde los inputs son de mayor tamaƱo (como en habla/audio o visiĆ³n), querrĆ”s pasar un generador en lugar de una lista que carga todos los inputs en memoria. Ve la [documentaciĆ³n del pipeline](./main_classes/pipelines) para mĆ”s informaciĆ³n.
-### Use otro modelo y otro tokenizador en el pipeline
+### Usa otro modelo y otro tokenizador en el pipeline
-El [`pipeline`] puede adaptarse a cualquier modelo del [Model Hub](https://huggingface.co/models) haciendo mĆ”s fĆ”cil adaptar el [`pipeline`] para otros casos de uso. Por ejemplo, si quisieras un modelo capaz de manejar texto en francĆ©s, usa los tags en el Model Hub para filtrar entre los modelos apropiados. El resultado mejor filtrado devuelve un [modelo BERT](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment) multilingual fine-tuned para el anĆ”lisis de sentimiento. Genial, Ā”vamos a usar este modelo!
+El [`pipeline`] puede acomodarse a cualquier modelo del [Model Hub](https://huggingface.co/models) haciendo mĆ”s fĆ”cil adaptar el [`pipeline`] para otros casos de uso. Por ejemplo, si quisieras un modelo capaz de manejar texto en francĆ©s, usa los tags en el Model Hub para filtrar entre los modelos apropiados. El resultado mejor filtrado devuelve un [modelo BERT](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment) multilingual fine-tuned para el anĆ”lisis de sentimiento. Genial, Ā”vamos a usar este modelo!
```py
>>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
@@ -188,7 +182,7 @@ Si no pudieras encontrar el modelo para tu caso respectivo de uso necesitarƔs a
-Debajo del capĆ³, las clases [`AutoModelForSequenceClassification`] y [`AutoTokenizer`] trabajan juntas para dar poder al [`pipeline`]. Una [AutoClass](./model_doc/auto) es un atajo que automĆ”ticamente recupera la arquitectura de un modelo preentrenado con su nombre o el path. SĆ³lo necesitarĆ”s seleccionar el `AutoClass` apropiado para tu tarea y tu tokenizador asociado con [`AutoTokenizer`].
+Por debajo, las clases [`AutoModelForSequenceClassification`] y [`AutoTokenizer`] trabajan juntas para dar poder al [`pipeline`]. Una [AutoClass](./model_doc/auto) es un atajo que automĆ”ticamente recupera la arquitectura de un modelo preentrenado con su nombre o el path. SĆ³lo necesitarĆ”s seleccionar el `AutoClass` apropiado para tu tarea y tu tokenizador asociado con [`AutoTokenizer`].
Regresemos a nuestro ejemplo y veamos cĆ³mo puedes usar el `AutoClass` para reproducir los resultados del [`pipeline`].
@@ -201,8 +195,8 @@ Carga un tokenizador con [`AutoTokenizer`]:
```py
>>> from transformers import AutoTokenizer
->>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
->>> tokenizer = AutoTokenizer.from_pretrained(model_name)
+>>> nombre_del_modelo = "nlptown/bert-base-multilingual-uncased-sentiment"
+>>> tokenizer = AutoTokenizer.from_pretrained(nombre_del_modelo)
```
DespuĆ©s, el tokenizador convierte los tokens a nĆŗmeros para construir un tensor que servirĆ” como input para el modelo. Esto es conocido como el *vocabulario* del modelo.
@@ -210,11 +204,11 @@ DespuĆ©s, el tokenizador convierte los tokens a nĆŗmeros para construir un tenso
Pasa tu texto al tokenizador:
```py
->>> encoding = tokenizer("We are very happy to show you the š¤ Transformers library.")
+>>> encoding = tokenizer("Estamos muy felices de mostrarte la biblioteca de š¤ Transformers.")
>>> print(encoding)
-{'input_ids': [101, 11312, 10320, 12495, 19308, 10114, 11391, 10855, 10103, 100, 58263, 13299, 119, 102],
- 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
+{'input_ids': [101, 10602, 14000, 13653, 43353, 10107, 10102, 47201, 10218, 10106, 18283, 10102, 100, 58263, 119, 102],
+ 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
```
El tokenizador devolverĆ” un diccionario conteniendo:
@@ -342,7 +336,7 @@ Los outputs del modelo tambiƩn se comportan como tuplas o diccionarios (e.g., p
-Una vez que tu modelo estƩ fine-tuned puedes guardarlo con tu tokenizador usando [`PreTrainedModel.save_pretrained`]:
+Una vez que se haya hecho fine-tuning a tu modelo puedes guardarlo con tu tokenizador usando [`PreTrainedModel.save_pretrained`]:
```py
>>> pt_save_directory = "./pt_save_pretrained"
@@ -359,7 +353,7 @@ Cuando quieras usar el modelo otra vez cƔrgalo con [`PreTrainedModel.from_pretr
-Una vez que tu modelo estƩ fine-tuned puedes guardarlo con tu tokenizador usando [`TFPreTrainedModel.save_pretrained`]:
+Una vez que se haya hecho fine-tuning a tu modelo puedes guardarlo con tu tokenizador usando [`TFPreTrainedModel.save_pretrained`]:
```py
>>> tf_save_directory = "./tf_save_pretrained"
@@ -375,7 +369,7 @@ Cuando quieras usar el modelo otra vez cƔrgalo con [`TFPreTrainedModel.from_pre
-Una caracterĆstica particularmente cool de š¤ Transformers es la habilidad de guardar el modelo y cargarlo como un modelo de PyTorch o TensorFlow. El parĆ”metro `from_pt` o `from_tf` puede convertir el modelo de un framework al otro:
+Una caracterĆstica particularmente interesante de š¤ Transformers es la habilidad de guardar el modelo y cargarlo como un modelo de PyTorch o TensorFlow. El parĆ”metro `from_pt` o `from_tf` puede convertir el modelo de un framework al otro:
diff --git a/docs/source/es/sagemaker.mdx b/docs/source/es/sagemaker.mdx
new file mode 100644
index 00000000000000..491d93e10d4d14
--- /dev/null
+++ b/docs/source/es/sagemaker.mdx
@@ -0,0 +1,25 @@
+
+
+# Ejecutar el entrenamiento en Amazon SageMaker
+
+La documentaciĆ³n ha sido trasladada a [hf.co/docs/sagemaker](https://huggingface.co/docs/sagemaker). Esta pĆ”gina serĆ” eliminada en `transformers` 5.0.
+
+### Tabla de contenido
+
+- [Entrenar modelos de Hugging Face en Amazon SageMaker con SageMaker Python SDK](https://huggingface.co/docs/sagemaker/train)
+- [Desplegar modelos de Hugging Face en Amazon SageMaker con SageMaker Python SDK](https://huggingface.co/docs/sagemaker/inference)
+- [Preguntas Frecuentes](https://huggingface.co/docs/sagemaker/faq)
diff --git a/docs/source/es/tasks/image_classification.mdx b/docs/source/es/tasks/image_classification.mdx
new file mode 100644
index 00000000000000..9b8b03207d0822
--- /dev/null
+++ b/docs/source/es/tasks/image_classification.mdx
@@ -0,0 +1,169 @@
+
+
+# ClasificaciĆ³n de imĆ”genes
+
+
+
+La clasificaciĆ³n de imĆ”genes asigna una etiqueta o clase a una imagen. A diferencia de la clasificaciĆ³n de texto o audio, las entradas son los valores de los pĆxeles que representan una imagen. La clasificaciĆ³n de imĆ”genes tiene muchos usos, como la detecciĆ³n de daƱos tras una catĆ”strofe, el control de la salud de los cultivos o la bĆŗsqueda de signos de enfermedad en imĆ”genes mĆ©dicas.
+
+Esta guĆa te mostrarĆ” como hacer fine-tune al [ViT](https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/vit) en el dataset [Food-101](https://huggingface.co/datasets/food101) para clasificar un alimento en una imagen.
+
+
+
+Consulta la [pĆ”gina de la tarea](https://huggingface.co/tasks/audio-classification) de clasificaciĆ³n de imĆ”genes para obtener mĆ”s informaciĆ³n sobre sus modelos, datasets y mĆ©tricas asociadas.
+
+
+
+## Carga el dataset Food-101
+
+Carga solo las primeras 5000 imĆ”genes del dataset Food-101 de la biblioteca š¤ de Datasets ya que es bastante grande:
+
+```py
+>>> from datasets import load_dataset
+
+>>> food = load_dataset("food101", split="train[:5000]")
+```
+
+Divide el dataset en un train y un test set:
+
+```py
+>>> food = food.train_test_split(test_size=0.2)
+```
+
+A continuaciĆ³n, observa un ejemplo:
+
+```py
+>>> food["train"][0]
+{'image': ,
+ 'label': 79}
+```
+
+El campo `image` contiene una imagen PIL, y cada `label` es un nĆŗmero entero que representa una clase. Crea un diccionario que asigne un nombre de label a un entero y viceversa. El mapeo ayudarĆ” al modelo a recuperar el nombre de label a partir del nĆŗmero de la misma:
+
+```py
+>>> labels = food["train"].features["label"].names
+>>> label2id, id2label = dict(), dict()
+>>> for i, label in enumerate(labels):
+... label2id[label] = str(i)
+... id2label[str(i)] = label
+```
+
+Ahora puedes convertir el nĆŗmero de label en un nombre de label para obtener mĆ”s informaciĆ³n:
+
+```py
+>>> id2label[str(79)]
+'prime_rib'
+```
+
+Cada clase de alimento - o label - corresponde a un nĆŗmero; `79` indica una costilla de primera en el ejemplo anterior.
+
+## Preprocesa
+
+Carga el feature extractor de ViT para procesar la imagen en un tensor:
+
+```py
+>>> from transformers import AutoFeatureExtractor
+
+>>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
+```
+
+Aplica varias transformaciones de imagen al dataset para hacer el modelo mĆ”s robusto contra el overfitting. En este caso se utilizarĆ” el mĆ³dulo [`transforms`](https://pytorch.org/vision/stable/transforms.html) de torchvision. Recorta una parte aleatoria de la imagen, cambia su tamaƱo y normalĆzala con la media y la desviaciĆ³n estĆ”ndar de la imagen:
+
+```py
+>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
+
+>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+>>> _transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
+```
+
+Crea una funciĆ³n de preprocesamiento que aplique las transformaciones y devuelva los `pixel_values` - los inputs al modelo - de la imagen:
+
+```py
+>>> def transforms(examples):
+... examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
+... del examples["image"]
+... return examples
+```
+
+Utiliza el mĆ©todo [`with_transform`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?#datasets.Dataset.with_transform) de š¤ Dataset para aplicar las transformaciones sobre todo el dataset. Las transformaciones se aplican sobre la marcha cuando se carga un elemento del dataset:
+
+```py
+>>> food = food.with_transform(transforms)
+```
+
+Utiliza [`DefaultDataCollator`] para crear un batch de ejemplos. A diferencia de otros data collators en š¤ Transformers, el DefaultDataCollator no aplica un preprocesamiento adicional como el padding.
+
+```py
+>>> from transformers import DefaultDataCollator
+
+>>> data_collator = DefaultDataCollator()
+```
+
+## Entrena
+Carga ViT con [`AutoModelForImageClassification`]. Especifica el nĆŗmero de labels, y pasa al modelo el mapping entre el nĆŗmero de label y la clase de label:
+
+```py
+>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
+
+>>> model = AutoModelForImageClassification.from_pretrained(
+... "google/vit-base-patch16-224-in21k",
+... num_labels=len(labels),
+... id2label=id2label,
+... label2id=label2id,
+... )
+```
+
+
+
+Si no estĆ”s familiarizado con el fine-tuning de un modelo con el [`Trainer`], echa un vistazo al tutorial bĆ”sico [aquĆ](../training#finetune-with-trainer)!
+
+
+
+Al llegar a este punto, solo quedan tres pasos:
+
+1. Define tus hiperparƔmetros de entrenamiento en [`TrainingArguments`]. Es importante que no elimines las columnas que no se utilicen, ya que esto harƔ que desaparezca la columna `image`. Sin la columna `image` no puedes crear `pixel_values`. Establece `remove_unused_columns=False` para evitar este comportamiento.
+2. Pasa los training arguments al [`Trainer`] junto con el modelo, los datasets, tokenizer y data collator.
+3. Llama [`~Trainer.train`] para hacer fine-tune de tu modelo.
+
+```py
+>>> training_args = TrainingArguments(
+... output_dir="./results",
+... per_device_train_batch_size=16,
+... evaluation_strategy="steps",
+... num_train_epochs=4,
+... fp16=True,
+... save_steps=100,
+... eval_steps=100,
+... logging_steps=10,
+... learning_rate=2e-4,
+... save_total_limit=2,
+... remove_unused_columns=False,
+... )
+
+>>> trainer = Trainer(
+... model=model,
+... args=training_args,
+... data_collator=data_collator,
+... train_dataset=food["train"],
+... eval_dataset=food["test"],
+... tokenizer=feature_extractor,
+... )
+
+>>> trainer.train()
+```
+
+
+
+Para ver un ejemplo mĆ”s a profundidad de cĆ³mo hacer fine-tune a un modelo para clasificaciĆ³n de imĆ”genes, echa un vistazo al correspondiente [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
+
+
diff --git a/docs/source/es/tasks/language_modeling.mdx b/docs/source/es/tasks/language_modeling.mdx
new file mode 100644
index 00000000000000..33962a49887012
--- /dev/null
+++ b/docs/source/es/tasks/language_modeling.mdx
@@ -0,0 +1,418 @@
+
+
+# Modelado de lenguaje
+
+El modelado de lenguaje predice palabras en un enunciado. Hay dos formas de modelado de lenguaje.
+
+
+
+El modelado de lenguaje causal predice el siguiente token en una secuencia de tokens, y el modelo solo puede considerar los tokens a la izquierda.
+
+
+
+El modelado de lenguaje por enmascaramiento predice un token enmascarado en una secuencia, y el modelo puede considerar los tokens bidireccionalmente.
+
+Esta guĆa te mostrarĆ” cĆ³mo realizar fine-tuning [DistilGPT2](https://huggingface.co/distilgpt2) para modelos de lenguaje causales y [DistilRoBERTa](https://huggingface.co/distilroberta-base) para modelos de lenguaje por enmascaramiento en el [r/askscience](https://www.reddit.com/r/askscience/) subdataset [ELI5](https://huggingface.co/datasets/eli5).
+
+
+
+Puedes realizar fine-tuning a otras arquitecturas para modelos de lenguaje como [GPT-Neo](https://huggingface.co/EleutherAI/gpt-neo-125M), [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) y [BERT](https://huggingface.co/bert-base-uncased) siguiendo los mismos pasos presentados en esta guĆa!
+
+Mira la [pĆ”gina de tarea](https://huggingface.co/tasks/text-generation) para generaciĆ³n de texto y la [pĆ”gina de tarea](https://huggingface.co/tasks/fill-mask) para modelos de lenguajes por enmascaramiento para obtener mĆ”s informaciĆ³n sobre los modelos, datasets, y mĆ©tricas asociadas.
+
+
+
+## Carga el dataset ELI5
+
+Carga solo los primeros 5000 registros desde la biblioteca š¤ Datasets, dado que es bastante grande:
+
+```py
+>>> from datasets import load_dataset
+
+>>> eli5 = load_dataset("eli5", split="train_asks[:5000]")
+```
+
+Divide este dataset en subdatasets para el entrenamiento y el test:
+
+```py
+eli5 = eli5.train_test_split(test_size=0.2)
+```
+
+Luego observa un ejemplo:
+
+```py
+>>> eli5["train"][0]
+{'answers': {'a_id': ['c3d1aib', 'c3d4lya'],
+ 'score': [6, 3],
+ 'text': ["The velocity needed to remain in orbit is equal to the square root of Newton's constant times the mass of earth divided by the distance from the center of the earth. I don't know the altitude of that specific mission, but they're usually around 300 km. That means he's going 7-8 km/s.\n\nIn space there are no other forces acting on either the shuttle or the guy, so they stay in the same position relative to each other. If he were to become unable to return to the ship, he would presumably run out of oxygen, or slowly fall into the atmosphere and burn up.",
+ "Hope you don't mind me asking another question, but why aren't there any stars visible in this photo?"]},
+ 'answers_urls': {'url': []},
+ 'document': '',
+ 'q_id': 'nyxfp',
+ 'selftext': '_URL_0_\n\nThis was on the front page earlier and I have a few questions about it. Is it possible to calculate how fast the astronaut would be orbiting the earth? Also how does he stay close to the shuttle so that he can return safely, i.e is he orbiting at the same speed and can therefore stay next to it? And finally if his propulsion system failed, would he eventually re-enter the atmosphere and presumably die?',
+ 'selftext_urls': {'url': ['http://apod.nasa.gov/apod/image/1201/freeflyer_nasa_3000.jpg']},
+ 'subreddit': 'askscience',
+ 'title': 'Few questions about this space walk photograph.',
+ 'title_urls': {'url': []}}
+```
+
+Observa que `text` es un subcampo anidado dentro del diccionario `answers`. Cuando preproceses el dataset, deberƔs extraer el subcampo `text` en una columna aparte.
+
+## Preprocesamiento
+
+
+
+Para modelados de lenguaje causales carga el tokenizador DistilGPT2 para procesar el subcampo `text`:
+
+```py
+>>> from transformers import AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+```
+
+
+
+Para modelados de lenguaje por enmascaramiento carga el tokenizador DistilRoBERTa, en lugar de DistilGPT2:
+
+```py
+>>> from transformers import AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
+```
+
+Extrae el subcampo `text` desde su estructura anidado con el mƩtodo [`flatten`](https://huggingface.co/docs/datasets/process.html#flatten):
+
+```py
+>>> eli5 = eli5.flatten()
+>>> eli5["train"][0]
+{'answers.a_id': ['c3d1aib', 'c3d4lya'],
+ 'answers.score': [6, 3],
+ 'answers.text': ["The velocity needed to remain in orbit is equal to the square root of Newton's constant times the mass of earth divided by the distance from the center of the earth. I don't know the altitude of that specific mission, but they're usually around 300 km. That means he's going 7-8 km/s.\n\nIn space there are no other forces acting on either the shuttle or the guy, so they stay in the same position relative to each other. If he were to become unable to return to the ship, he would presumably run out of oxygen, or slowly fall into the atmosphere and burn up.",
+ "Hope you don't mind me asking another question, but why aren't there any stars visible in this photo?"],
+ 'answers_urls.url': [],
+ 'document': '',
+ 'q_id': 'nyxfp',
+ 'selftext': '_URL_0_\n\nThis was on the front page earlier and I have a few questions about it. Is it possible to calculate how fast the astronaut would be orbiting the earth? Also how does he stay close to the shuttle so that he can return safely, i.e is he orbiting at the same speed and can therefore stay next to it? And finally if his propulsion system failed, would he eventually re-enter the atmosphere and presumably die?',
+ 'selftext_urls.url': ['http://apod.nasa.gov/apod/image/1201/freeflyer_nasa_3000.jpg'],
+ 'subreddit': 'askscience',
+ 'title': 'Few questions about this space walk photograph.',
+ 'title_urls.url': []}
+```
+
+Cada subcampo es ahora una columna separada, como lo indica el prefijo `answers`. Observa que `answers.text` es una lista. En lugar de tokenizar cada enunciado por separado, convierte la lista en un string para tokenizarlos conjuntamente.
+
+AsĆ es como puedes crear una funciĆ³n de preprocesamiento para convertir la lista en una cadena y truncar las secuencias para que no superen la longitud mĆ”xima de input de DistilGPT2:
+
+```py
+>>> def preprocess_function(examples):
+... return tokenizer([" ".join(x) for x in examples["answers.text"]], truncation=True)
+```
+
+Usa de š¤ Datasets la funciĆ³n [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) para aplicar la funciĆ³n de preprocesamiento sobre el dataset en su totalidad. Puedes acelerar la funciĆ³n `map` configurando el argumento `batched=True` para procesar mĆŗltiples elementos del dataset a la vez y aumentar la cantidad de procesos con `num_proc`. Elimina las columnas que no necesitas:
+
+```py
+>>> tokenized_eli5 = eli5.map(
+... preprocess_function,
+... batched=True,
+... num_proc=4,
+... remove_columns=eli5["train"].column_names,
+... )
+```
+
+Ahora necesitas una segunda funciĆ³n de preprocesamiento para capturar el texto truncado de cualquier ejemplo demasiado largo para evitar cualquier pĆ©rdida de informaciĆ³n. Esta funciĆ³n de preprocesamiento deberĆa:
+
+- Concatenar todo el texto.
+- Dividir el texto concatenado en trozos mƔs pequeƱos definidos por un `block_size`.
+
+```py
+>>> block_size = 128
+
+
+>>> def group_texts(examples):
+... concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
+... total_length = len(concatenated_examples[list(examples.keys())[0]])
+... result = {
+... k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+... for k, t in concatenated_examples.items()
+... }
+... result["labels"] = result["input_ids"].copy()
+... return result
+```
+
+Aplica la funciĆ³n `group_texts` sobre todo el dataset:
+
+```py
+>>> lm_dataset = tokenized_eli5.map(group_texts, batched=True, num_proc=4)
+```
+
+Para modelados de lenguaje causales, usa [`DataCollatorForLanguageModeling`] para crear un lote de ejemplos. Esto tambiĆ©n *rellenarĆ” dinĆ”micamente* tu texto a la dimensiĆ³n del elemento mĆ”s largo del lote para que de esta manera tengan largo uniforme. Si bien es posible rellenar tu texto en la funciĆ³n `tokenizer` mediante el argumento `padding=True`, el rellenado dinĆ”mico es mĆ”s eficiente.
+
+
+
+Puedes usar el token de final de secuencia como el token de relleno y asignar `mlm=False`. Esto usarĆ” los inputs como etiquetas movidas un elemento hacia la derecha:
+
+```py
+>>> from transformers import DataCollatorForLanguageModeling
+
+>>> tokenizer.pad_token = tokenizer.eos_token
+>>> data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+```
+
+Para modelados de lenguaje por enmascaramiento usa el mismo [`DataCollatorForLanguageModeling`] excepto que deberƔs especificar `mlm_probability` para enmascarar tokens aleatoriamente cada vez que iteras sobre los datos.
+
+```py
+>>> from transformers import DataCollatorForLanguageModeling
+
+>>> tokenizer.pad_token = tokenizer.eos_token
+>>> data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+```
+
+
+Puedes usar el token de final de secuencia como el token de relleno y asignar `mlm=False`. Esto usarĆ” los inputs como etiquetas movidas un elemento hacia la derecha:
+
+```py
+>>> from transformers import DataCollatorForLanguageModeling
+
+>>> data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="tf")
+```
+
+Para modelados de lenguajes por enmascaramiento usa el mismo [`DataCollatorForLanguageModeling`] excepto que deberƔs especificar `mlm_probability` para enmascarar tokens aleatoriamente cada vez que iteras sobre los datos.
+
+```py
+>>> from transformers import DataCollatorForLanguageModeling
+
+>>> data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="tf")
+```
+
+
+
+## Modelado de lenguaje causal
+
+El modelado de lenguaje causal es frecuentemente utilizado para generaciĆ³n de texto. Esta secciĆ³n te muestra cĆ³mo realizar fine-tuning a [DistilGPT2](https://huggingface.co/distilgpt2) para generar nuevo texto.
+
+### Entrenamiento
+
+
+
+Carga DistilGPT2 con [`AutoModelForCausalLM`]:
+
+```py
+>>> from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
+
+>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+```
+
+
+
+Si no estĆ”s familiarizado con el proceso de realizar fine-tuning sobre un modelo con [`Trainer`], considera el tutorial bĆ”sico [aquĆ](../training#finetune-with-trainer)!
+
+
+
+A este punto, solo faltan tres pasos:
+
+1. Definir tus hiperparƔmetros de entrenamiento en [`TrainingArguments`].
+2. Pasarle los argumentos de entrenamiento a [`Trainer`] junto con el modelo, dataset, y el data collator.
+3. Realiza la llamada [`~Trainer.train`] para realizar el fine-tuning sobre tu modelo.
+
+```py
+>>> training_args = TrainingArguments(
+... output_dir="./results",
+... evaluation_strategy="epoch",
+... learning_rate=2e-5,
+... weight_decay=0.01,
+... )
+
+>>> trainer = Trainer(
+... model=model,
+... args=training_args,
+... train_dataset=lm_dataset["train"],
+... eval_dataset=lm_dataset["test"],
+... data_collator=data_collator,
+... )
+
+>>> trainer.train()
+```
+
+
+Para realizar el fine-tuning de un modelo en TensorFlow, comienza por convertir tus datasets al formato `tf.data.Dataset` con [`to_tf_dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.to_tf_dataset). Especifica los inputs y etiquetas en `columns`, ya sea para mezclar el dataset, tamaƱo de lote, y el data collator:
+
+```py
+>>> tf_train_set = lm_dataset["train"].to_tf_dataset(
+... columns=["attention_mask", "input_ids", "labels"],
+... dummy_labels=True,
+... shuffle=True,
+... batch_size=16,
+... collate_fn=data_collator,
+... )
+
+>>> tf_test_set = lm_dataset["test"].to_tf_dataset(
+... columns=["attention_mask", "input_ids", "labels"],
+... dummy_labels=True,
+... shuffle=False,
+... batch_size=16,
+... collate_fn=data_collator,
+... )
+```
+
+
+
+Si no estĆ”s familiarizado con realizar fine-tuning de tus modelos con Keras, considera el tutorial bĆ”sico [aquĆ](training#finetune-with-keras)!
+
+
+
+Crea la funciĆ³n optimizadora, la tasa de aprendizaje, y algunos hiperparĆ”metros de entrenamiento:
+
+```py
+>>> from transformers import create_optimizer, AdamWeightDecay
+
+>>> optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)
+```
+
+Carga DistilGPT2 con [`TFAutoModelForCausalLM`]:
+
+```py
+>>> from transformers import TFAutoModelForCausalLM
+
+>>> model = TFAutoModelForCausalLM.from_pretrained("distilgpt2")
+```
+
+Configura el modelo para entrenamiento con [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+
+```py
+>>> import tensorflow as tf
+
+>>> model.compile(optimizer=optimizer)
+```
+
+Llama a [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) para realizar el fine-tuning del modelo:
+
+```py
+>>> model.fit(x=tf_train_set, validation_data=tf_test_set, epochs=3)
+```
+
+
+
+## Modelado de lenguaje por enmascaramiento
+
+El modelado de lenguaje por enmascaramiento es tambiĆ©n conocido como una tarea de rellenar la mĆ”scara, pues predice un token enmascarado dada una secuencia. Los modelos de lenguaje por enmascaramiento requieren una buena comprensiĆ³n del contexto de una secuencia entera, en lugar de solo el contexto a la izquierda. Esta secciĆ³n te enseƱa como realizar el fine-tuning de [DistilRoBERTa](https://huggingface.co/distilroberta-base) para predecir una palabra enmascarada.
+
+### Entrenamiento
+
+
+
+Carga DistilRoBERTa con [`AutoModelForMaskedlM`]:
+
+```py
+>>> from transformers import AutoModelForMaskedLM
+
+>>> model = AutoModelForMaskedLM.from_pretrained("distilroberta-base")
+```
+
+
+
+Si no estĆ”s familiarizado con el proceso de realizar fine-tuning sobre un modelo con [`Trainer`], considera el tutorial bĆ”sico [aquĆ](../training#finetune-with-trainer)!
+
+
+
+A este punto, solo faltan tres pasos:
+
+1. Definir tus hiperparƔmetros de entrenamiento en [`TrainingArguments`].
+2. Pasarle los argumentos de entrenamiento a [`Trainer`] junto con el modelo, dataset, y el data collator.
+3. Realiza la llamada [`~Trainer.train`] para realizar el fine-tuning de tu modelo.
+
+```py
+>>> training_args = TrainingArguments(
+... output_dir="./results",
+... evaluation_strategy="epoch",
+... learning_rate=2e-5,
+... num_train_epochs=3,
+... weight_decay=0.01,
+... )
+
+>>> trainer = Trainer(
+... model=model,
+... args=training_args,
+... train_dataset=lm_dataset["train"],
+... eval_dataset=lm_dataset["test"],
+... data_collator=data_collator,
+... )
+
+>>> trainer.train()
+```
+
+
+Para realizar el fine-tuning de un modelo en TensorFlow, comienza por convertir tus datasets al formato `tf.data.Dataset` con [`to_tf_dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.to_tf_dataset). Especifica los inputs y etiquetas en `columns`, ya sea para mezclar el dataset, tamaƱo de lote, y el data collator:
+
+```py
+>>> tf_train_set = lm_dataset["train"].to_tf_dataset(
+... columns=["attention_mask", "input_ids", "labels"],
+... dummy_labels=True,
+... shuffle=True,
+... batch_size=16,
+... collate_fn=data_collator,
+... )
+
+>>> tf_test_set = lm_dataset["test"].to_tf_dataset(
+... columns=["attention_mask", "input_ids", "labels"],
+... dummy_labels=True,
+... shuffle=False,
+... batch_size=16,
+... collate_fn=data_collator,
+... )
+```
+
+
+
+Si no estĆ”s familiarizado con realizar fine-tuning de tus modelos con Keras, considera el tutorial bĆ”sico [aquĆ](training#finetune-with-keras)!
+
+
+
+Crea la funciĆ³n optimizadora, la tasa de aprendizaje, y algunos hiperparĆ”metros de entrenamiento:
+
+```py
+>>> from transformers import create_optimizer, AdamWeightDecay
+
+>>> optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)
+```
+
+Carga DistilRoBERTa con [`TFAutoModelForMaskedLM`]:
+
+```py
+>>> from transformers import TFAutoModelForMaskedLM
+
+>>> model = TFAutoModelForCausalLM.from_pretrained("distilroberta-base")
+```
+
+Configura el modelo para entrenamiento con [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+
+```py
+>>> import tensorflow as tf
+
+>>> model.compile(optimizer=optimizer)
+```
+
+Llama a [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) para realizar el fine-tuning del modelo:
+
+```py
+>>> model.fit(x=tf_train_set, validation_data=tf_test_set, epochs=3)
+```
+
+
+
+
+
+Para un ejemplo mĆ”s profundo sobre cĆ³mo realizar el fine-tuning sobre un modelo de lenguaje causal, considera
+[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb)
+o [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
+
+
\ No newline at end of file
diff --git a/docs/source/es/training.mdx b/docs/source/es/training.mdx
index 653e9437ae2ebd..eefe96f9e80d8d 100644
--- a/docs/source/es/training.mdx
+++ b/docs/source/es/training.mdx
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-El uso de un modelo pre-entrenado tiene importantes ventajas. Reduce los costos de computaciĆ³n, la huella de carbono, y te permite utilizar modelos de Ćŗltima generaciĆ³n sin tener que entrenar uno desde cero. š¤ Transformers proporciona acceso a miles de modelos pre-entrenados en una amplia gama de tareas. Cuando utilizas un modelo pre-entrenado, lo entrenas con un dataset especĆfico para tu tarea. Esto se conoce como fine-tuning, una tĆ©cnica de entrenamiento increĆblemente poderosa. En este tutorial haremos fine-tuning a un modelo pre-entrenado con un framework de Deep Learning de tu elecciĆ³n:
+El uso de un modelo pre-entrenado tiene importantes ventajas. Reduce los costos de computaciĆ³n, la huella de carbono y te permite utilizar modelos de Ćŗltima generaciĆ³n sin tener que entrenar uno desde cero.
* Fine-tuning a un modelo pre-entrenado con š¤ Transformers [`Trainer`].
* Fine-tuning a un modelo pre-entrenado en TensorFlow con Keras.
@@ -39,7 +39,7 @@ Comienza cargando el dataset de [Yelp Reviews](https://huggingface.co/datasets/y
'text': 'My expectations for McDonalds are t rarely high. But for one to still fail so spectacularly...that takes something special!\\nThe cashier took my friends\'s order, then promptly ignored me. I had to force myself in front of a cashier who opened his register to wait on the person BEHIND me. I waited over five minutes for a gigantic order that included precisely one kid\'s meal. After watching two people who ordered after me be handed their food, I asked where mine was. The manager started yelling at the cashiers for \\"serving off their orders\\" when they didn\'t have their food. But neither cashier was anywhere near those controls, and the manager was the one serving food to customers and clearing the boards.\\nThe manager was rude when giving me my order. She didn\'t make sure that I had everything ON MY RECEIPT, and never even had the decency to apologize that I felt I was getting poor service.\\nI\'ve eaten at various McDonalds restaurants for over 30 years. I\'ve worked at more than one location. I expect bad days, bad moods, and the occasional mistake. But I have yet to have a decent experience at this store. It will remain a place I avoid unless someone in my party needs to avoid illness from low blood sugar. Perhaps I should go back to the racially biased service of Steak n Shake instead!'}
```
-Como ya sabes, necesitas un tokenizador para procesar el texto e incluir una estrategia para el padding y el truncamiento, para manejar cualquier longitud de secuencia variable. Para procesar tu dataset en un solo paso, utiliza el mĆ©todo de š¤ Datasets [`map`](https://huggingface.co/docs/datasets/process.html#map) para aplicar una funciĆ³n de preprocesamiento sobre todo el dataset:
+Como ya sabes, necesitas un tokenizador para procesar el texto e incluir una estrategia para el padding y el truncamiento para manejar cualquier longitud de secuencia variable. Para procesar tu dataset en un solo paso, utiliza el mĆ©todo de š¤ DatasetsĀ mappara aplicar una funciĆ³n de preprocesamiento sobre todo el dataset:
```py
>>> from transformers import AutoTokenizer
@@ -79,7 +79,7 @@ Comienza cargando tu modelo y especifica el nĆŗmero de labels previstas. A parti
-VerƔs una advertencia acerca de que algunos de los pesos pre-entrenados que no estƔn siendo utilizados y que algunos pesos estƔn siendo inicializados al azar.
+VerƔs una advertencia acerca de que algunos de los pesos pre-entrenados no estƔn siendo utilizados y que algunos pesos estƔn siendo inicializados al azar. No te preocupes, esto es completamente normal.
No te preocupes, esto es completamente normal. El head/cabezal pre-entrenado del modelo BERT se descarta y se sustituye por un head de clasificaciĆ³n inicializado aleatoriamente. Puedes aplicar fine-tuning a este nuevo head del modelo en tu tarea de clasificaciĆ³n de secuencias haciendo transfer learning del modelo pre-entrenado.
@@ -98,7 +98,7 @@ Especifica dĆ³nde vas a guardar los checkpoints de tu entrenamiento:
### MĆ©tricas
-El [`Trainer`] no evalĆŗa automĆ”ticamente el rendimiento del modelo durante el entrenamiento. TendrĆ”s que pasarle a [`Trainer`] una funciĆ³n para calcular y hacer un reporte de las mĆ©tricas. La librerĆa de š¤ Datasets proporciona una funciĆ³n de [`accuracy`](https://huggingface.co/metrics/accuracy) simple que puedes cargar con la funciĆ³n `load_metric` (ver este [tutorial](https://huggingface.co/docs/datasets/metrics.html) para mĆ”s informaciĆ³n):
+El [`Trainer`] no evalĆŗa automĆ”ticamente el rendimiento del modelo durante el entrenamiento. TendrĆ”s que pasarle a [`Trainer`] una funciĆ³n para calcular y hacer un reporte de las mĆ©tricas. La biblioteca de š¤ Datasets proporciona una funciĆ³n de [`accuracy`](https://huggingface.co/metrics/accuracy) simple que puedes cargar con la funciĆ³n `load_metric` (ver este [tutorial](https://huggingface.co/docs/datasets/metrics.html) para mĆ”s informaciĆ³n):
```py
>>> import numpy as np
@@ -126,7 +126,7 @@ Si quieres controlar tus mĆ©tricas de evaluaciĆ³n durante el fine-tuning, especi
### Trainer
-Crea un objeto [`Trainer`] con tu modelo, argumentos de entrenamiento, conjuntos de datos de entrenamiento y de prueba, y tu funciĆ³n de evaluaciĆ³n:
+Crea un objeto [`Trainer`] con tu modelo, argumentos de entrenamiento, datasets de entrenamiento y de prueba, y tu funciĆ³n de evaluaciĆ³n:
```py
>>> trainer = Trainer(
@@ -150,7 +150,7 @@ A continuaciĆ³n, aplica fine-tuning a tu modelo llamando [`~transformers.Trainer
-Los modelos de š¤ Transformers tambiĆ©n permiten realizar el entrenamiento en TensorFlow con la API de Keras. SĆ³lo es necesario hacer algunos cambios antes de hacer fine-tuning.
+Los modelos de š¤ Transformers tambiĆ©n permiten realizar el entrenamiento en TensorFlow con la API de Keras. Solo es necesario hacer algunos cambios antes de hacer fine-tuning.
### Convierte el dataset al formato de TensorFlow
@@ -217,7 +217,7 @@ A continuaciĆ³n, compila y aplica fine-tuning a tu modelo con [`fit`](https://ke
-El [`Trainer`] se encarga del ciclo de entrenamiento y permite aplicar fine-tuning a un modelo en una sola lĆnea de cĆ³digo. Para los usuarios que prefieren escribir tu propio ciclo de entrenamiento, tambiĆ©n puedes aplicar fine-tuning a un modelo de š¤ Transformers en PyTorch nativo.
+El [`Trainer`] se encarga del ciclo de entrenamiento y permite aplicar fine-tuning a un modelo en una sola lĆnea de cĆ³digo. Para los que prefieran escribir su propio ciclo de entrenamiento, tambiĆ©n pueden aplicar fine-tuning a un modelo de š¤ Transformers en PyTorch nativo.
En este punto, es posible que necesites reiniciar tu notebook o ejecutar el siguiente cĆ³digo para liberar algo de memoria:
@@ -248,7 +248,7 @@ A continuaciĆ³n, haremos un post-procesamiento manual al `tokenized_dataset` y a
>>> tokenized_datasets.set_format("torch")
```
-A continuaciĆ³n, crea un subconjunto mĆ”s pequeƱo del dataset, como se ha mostrado anteriormente, para acelerar el fine-tuning:
+A continuaciĆ³n, crea un subconjunto mĆ”s pequeƱo del dataset como se ha mostrado anteriormente para acelerar el fine-tuning:
```py
>>> small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
@@ -274,7 +274,7 @@ Carga tu modelo con el nĆŗmero de labels previstas:
>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
```
-### Optimiza y progrma el learning rate
+### Optimiza y programa el learning rate
Crea un optimizador y el learning rate para aplicar fine-tuning al modelo. Vamos a utilizar el optimizador [`AdamW`](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html) de PyTorch:
@@ -311,11 +311,11 @@ Consigue acceso gratuito a una GPU en la nube si es que no tienes este recurso d
-Genial, Ā”ahora estamos listos entrenar! š„³
+Genial, Ā”ahora podemos entrenar! š„³
### Ciclo de entrenamiento
-Para hacer un seguimiento al progreso del entrenamiento, utiliza la librerĆa [tqdm](https://tqdm.github.io/) para aƱadir una barra de progreso sobre el nĆŗmero de pasos de entrenamiento:
+Para hacer un seguimiento al progreso del entrenamiento, utiliza la biblioteca [tqdm](https://tqdm.github.io/) para aƱadir una barra de progreso sobre el nĆŗmero de pasos de entrenamiento:
```py
>>> from tqdm.auto import tqdm
diff --git a/docs/source/it/_config.py b/docs/source/it/_config.py
new file mode 100644
index 00000000000000..b05ae95c03adab
--- /dev/null
+++ b/docs/source/it/_config.py
@@ -0,0 +1,15 @@
+# docstyle-ignore
+INSTALL_CONTENT = """
+# Installazione di Transformers
+! pip install transformers datasets
+# Per installare dalla fonte invece dell'ultima versione rilasciata, commenta il comando sopra e
+# rimuovi la modalitĆ commento al comando seguente.
+# ! pip install git+https://github.com/huggingface/transformers.git
+"""
+
+notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
+black_avoid_patterns = {
+ "{processor_class}": "FakeProcessorClass",
+ "{model_class}": "FakeModelClass",
+ "{object_class}": "FakeObjectClass",
+}
diff --git a/docs/source/it/_toctree.yml b/docs/source/it/_toctree.yml
new file mode 100644
index 00000000000000..0d91bca083b596
--- /dev/null
+++ b/docs/source/it/_toctree.yml
@@ -0,0 +1,14 @@
+- sections:
+ - local: index
+ title: š¤ Transformers
+ - local: quicktour
+ title: Tour rapido
+ - local: installation
+ title: Installazione
+ title: Iniziare
+- sections:
+ - local: pipeline_tutorial
+ title: Pipeline per l'inferenza
+ - local: autoclass_tutorial
+ title: Carica istanze pre-allenate con AutoClass
+ title: Esercitazione
diff --git a/docs/source/it/autoclass_tutorial.mdx b/docs/source/it/autoclass_tutorial.mdx
new file mode 100644
index 00000000000000..88dd6cad6c4212
--- /dev/null
+++ b/docs/source/it/autoclass_tutorial.mdx
@@ -0,0 +1,119 @@
+
+
+# Carica istanze pre-allenate con AutoClass
+
+Con cosƬ tante architetture Transformer differenti, puĆ² essere sfidante crearne una per il tuo checkpoint. Come parte della filosofia centrale di š¤ Transformers per rendere la libreria facile, semplice e flessibile da utilizzare, una `AutoClass` inferisce e carica automaticamente l'architettura corretta da un dato checkpoint. Il metodo `from_pretrained` ti permette di caricare velocemente un modello pre-allenato per qualsiasi architettura, cosƬ non devi utilizzare tempo e risorse per allenare un modello da zero. Produrre questo codice agnostico ai checkpoint significa che se il tuo codice funziona per un checkpoint, funzionerĆ anche per un altro checkpoint, purchĆ© sia stato allenato per un compito simile, anche se l'architettura ĆØ differente.
+
+
+
+Ricorda, con architettura ci si riferisce allo scheletro del modello e con checkpoint ai pesi di una determinata architettura. Per esempio, [BERT](https://huggingface.co/bert-base-uncased) ĆØ un'architettura, mentre `bert-base-uncased` ĆØ un checkpoint. Modello ĆØ un termine generale che puĆ² significare sia architettura che checkpoint.
+
+
+
+In questo tutorial, imparerai a:
+
+* Caricare un tokenizer pre-allenato.
+* Caricare un estrattore di caratteristiche (feature extractor, in inglese) pre-allenato.
+* Caricare un processore pre-allenato.
+* Caricare un modello pre-allenato.
+
+## AutoTokenizer
+
+Quasi tutti i compiti di NLP iniziano con un tokenizer. Un tokenizer converte il tuo input in un formato che possa essere elaborato dal modello.
+
+Carica un tokenizer con [`AutoTokenizer.from_pretrained`]:
+
+```py
+>>> from transformers import AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
+```
+
+Poi tokenizza il tuo input come mostrato in seguito:
+
+```py
+>>> sequenza = "In un buco nel terreno viveva uno Hobbit."
+>>> print(tokenizer(sequenza))
+{'input_ids': [0, 360, 51, 373, 587, 1718, 54644, 22597, 330, 3269, 2291, 22155, 18, 5, 2],
+ 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
+```
+
+## AutoFeatureExtractor
+
+Per compiti inerenti a audio e video, un feature extractor processa il segnale audio o l'immagine nel formato di input corretto.
+
+Carica un feature extractor con [`AutoFeatureExtractor.from_pretrained`]:
+
+```py
+>>> from transformers import AutoFeatureExtractor
+
+>>> feature_extractor = AutoFeatureExtractor.from_pretrained(
+... "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
+... )
+```
+
+## AutoProcessor
+
+Compiti multimodali richiedono un processore che combini i due tipi di strumenti di elaborazione. Per esempio, il modello [LayoutLMV2](model_doc/layoutlmv2) richiede un feature extractor per gestire le immagine e un tokenizer per gestire il testo; un processore li combina entrambi.
+
+Carica un processore con [`AutoProcessor.from_pretrained`]:
+
+```py
+>>> from transformers import AutoProcessor
+
+>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+```
+
+## AutoModel
+
+
+
+Infine, le classi `AutoModelFor` ti permettono di caricare un modello pre-allenato per un determinato compito (guarda [qui](model_doc/auto) per una lista completa di compiti presenti). Per esempio, carica un modello per la classificazione di sequenze con [`AutoModelForSequenceClassification.from_pretrained`]:
+
+```py
+>>> from transformers import AutoModelForSequenceClassification
+
+>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
+```
+
+Semplicemente utilizza lo stesso checkpoint per caricare un'architettura per un task differente:
+
+```py
+>>> from transformers import AutoModelForTokenClassification
+
+>>> model = AutoModelForTokenClassification.from_pretrained("distilbert-base-uncased")
+```
+
+Generalmente, raccomandiamo di utilizzare la classe `AutoTokenizer` e la classe `AutoModelFor` per caricare istanze pre-allenate dei modelli. Questo ti assicurerĆ di aver caricato la corretta architettura ogni volta. Nel prossimo [tutorial](preprocessing), imparerai come utilizzare il tokenizer, il feature extractor e il processore per elaborare un dataset per il fine-tuning.
+
+
+
+Infine, le classi `TFAutoModelFor` ti permettono di caricare un modello pre-allenato per un determinato compito (guarda [qui](model_doc/auto) per una lista completa di compiti presenti). Per esempio, carica un modello per la classificazione di sequenze con [`TFAutoModelForSequenceClassification.from_pretrained`]:
+
+```py
+>>> from transformers import TFAutoModelForSequenceClassification
+
+>>> model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
+```
+
+Semplicemente utilizza lo stesso checkpoint per caricare un'architettura per un task differente:
+
+```py
+>>> from transformers import TFAutoModelForTokenClassification
+
+>>> model = TFAutoModelForTokenClassification.from_pretrained("distilbert-base-uncased")
+```
+
+Generalmente, raccomandiamo di utilizzare la classe `AutoTokenizer` e la classe `TFAutoModelFor` per caricare istanze pre-allenate dei modelli. Questo ti assicurerĆ di aver caricato la corretta architettura ogni volta. Nel prossimo [tutorial](preprocessing), imparerai come utilizzare il tokenizer, il feature extractor e il processore per elaborare un dataset per il fine-tuning.
+
+
diff --git a/docs/source/it/index.mdx b/docs/source/it/index.mdx
new file mode 100644
index 00000000000000..d5e10b7c4983cb
--- /dev/null
+++ b/docs/source/it/index.mdx
@@ -0,0 +1,291 @@
+
+
+# š¤ Transformers
+
+Machine Learning allo stato dell'arte per PyTorch, TensorFlow e JAX.
+
+š¤ Transformers fornisce delle API per scaricare in modo semplice e allenare modelli pre-allenati allo stato dell'arte. L'utilizzo di modelli pre-allenati puĆ² ridurre i tuoi costi computazionali, l'impatto ambientale, e farti risparmiare il tempo che utilizzeresti per allenare un modello da zero. I modelli possono essere utilizzati in diverse modalitĆ come ad esempio:
+
+* š Testo: classificazione del testo, estrazione delle informazioni, rispondere a domande, riassumere, traduzione e generazione del testo in piĆ¹ di 100 lingue.
+* š¼ļø Immagini: classificazione di immagini, rilevazione di oggetti e segmentazione.
+* š£ļø Audio: riconoscimento vocale e classificazione dell'audio.
+* š Multimodale: rispondere a domande inerenti dati tabulari, riconoscimento ottico dei caratteri, estrazione di informazioni a partire da documenti scannerizzati, classificazione di video e risposta visuale a domande.
+
+La nostra libreria supporta un'integrazione perfetta tra tre delle librerie per il deep learning piĆ¹ popolari: [PyTorch](https://pytorch.org/), [TensorFlow](https://www.tensorflow.org/) e [JAX](https://jax.readthedocs.io/en/latest/). Allena il tuo modello in tre righe di codice in un framework, e caricalo per l'inferenza in un altro.
+
+Ogni architettura di š¤ Transformers ĆØ definita in un modulo Python indipendente cosƬ da poter essere personalizzata in modo semplice per la ricerca e gli esperimenti.
+
+## Se stai cercando supporto personalizzato dal team di Hugging Face
+
+
+
+
+
+## Contenuti
+
+La documentazione ĆØ organizzata in cinque parti:
+
+- **INIZIARE** contiene un tour rapido e le istruzioni di installazione per cominciare ad utilizzare š¤ Transformers.
+- **TUTORIALS** ĆØ un buon posto da cui iniziare se per te la nostra libreria ĆØ nuova. Questa sezione ti aiuterĆ ad acquisire le competenze basilari di cui hai bisogno per iniziare ad utilizzare š¤ Transformers.
+- **GUIDE PRATICHE** ti mostrerĆ come raggiungere obiettivi specifici come fare fine-tuning di un modello pre-allenato per la modellizzazione del linguaggio o come creare una testa per un modello personalizzato.
+- **GUIDE CONCETTUALI** fornisce discussioni e spiegazioni dei concetti sottostanti alle idee dietro ai modelli, compiti, e la filosofia di progettazione di š¤ Transformers.
+- **API** descrive ogni classe e funzione, raggruppate in:
+ - **CLASSI PRINCIPALI** per le classi principali che espongono le API importanti della libreria.
+ - **MODELLI** per le classi e le funzioni relative ad ogni modello implementato all'interno della libreria.
+ - **HELPERS INTERNI** per le classi e le funzioni che utilizziamo internamente.
+
+La libreria attualmente contiene implementazioni in JAX, PyTorch e TensorFlow, pesi di modelli pre-allenati, script di utilizzo e strumenti di conversione per i seguenti modelli.
+
+### Modelli supportati
+
+
+
+1. **[ALBERT](model_doc/albert)** (da Google Research e l'Istituto Tecnologico di Chicago) rilasciato con il paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), da Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
+1. **[BART](model_doc/bart)** (da Facebook) rilasciato con il paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) da Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov e Luke Zettlemoyer.
+1. **[BARThez](model_doc/barthez)** (da politecnico di Ćcole) rilasciato con il paper [BARThez: a Skilled Pretrained French Sequence-to-Sequence Model](https://arxiv.org/abs/2010.12321) da Moussa Kamal Eddine, Antoine J.-P. Tixier, Michalis Vazirgiannis.
+1. **[BARTpho](model_doc/bartpho)** (da VinAI Research) rilasciato con il paper [BARTpho: Pre-trained Sequence-to-Sequence Models for Vietnamese](https://arxiv.org/abs/2109.09701) da Nguyen Luong Tran, Duong Minh Le e Dat Quoc Nguyen.
+1. **[BEiT](model_doc/beit)** (da Microsoft) rilasciato con il paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) da Hangbo Bao, Li Dong, Furu Wei.
+1. **[BERT](model_doc/bert)** (da Google) rilasciato con il paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) da Jacob Devlin, Ming-Wei Chang, Kenton Lee e Kristina Toutanova.
+1. **[BERTweet](model_doc/bertweet)** (da VinAI Research) rilasciato con il paper [BERTweet: A pre-trained language model for English Tweets](https://aclanthology.org/2020.emnlp-demos.2/) da Dat Quoc Nguyen, Thanh Vu e Anh Tuan Nguyen.
+1. **[BERT For Sequence Generation](model_doc/bert-generation)** (da Google) rilasciato con il paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) da Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
+1. **[BigBird-RoBERTa](model_doc/big_bird)** (da Google Research) rilasciato con il paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) da Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed.
+1. **[BigBird-Pegasus](model_doc/bigbird_pegasus)** (v Google Research) rilasciato con il paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) da Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed.
+1. **[Blenderbot](model_doc/blenderbot)** (da Facebook) rilasciato con il paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) da Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
+1. **[BlenderbotSmall](model_doc/blenderbot-small)** (da Facebook) rilasciato con il paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) da Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
+1. **[BORT](model_doc/bort)** (da Alexa) rilasciato con il paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) da Adrian de Wynter e Daniel J. Perry.
+1. **[ByT5](model_doc/byt5)** (da Google Research) rilasciato con il paper [ByT5: Towards a token-free future with pre-trained byte-to-byte models](https://arxiv.org/abs/2105.13626) da Linting Xue, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, Colin Raffel.
+1. **[CamemBERT](model_doc/camembert)** (da Inria/Facebook/Sorbonne) rilasciato con il paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) da Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz SuĆ”rez*, Yoann Dupont, Laurent Romary, Ćric Villemonte de la Clergerie, DjamĆ© Seddah e BenoĆ®t Sagot.
+1. **[CANINE](model_doc/canine)** (da Google Research) rilasciato con il paper [CANINE: Pre-training an Efficient Tokenization-Free Encoder for Language Representation](https://arxiv.org/abs/2103.06874) da Jonathan H. Clark, Dan Garrette, Iulia Turc, John Wieting.
+1. **[ConvNeXT](model_doc/convnext)** (da Facebook AI) rilasciato con il paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) da Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie.
+1. **[CLIP](model_doc/clip)** (da OpenAI) rilasciato con il paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) da Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
+1. **[ConvBERT](model_doc/convbert)** (da YituTech) rilasciato con il paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) da Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan.
+1. **[CPM](model_doc/cpm)** (dalla UniversitĆ di Tsinghua) rilasciato con il paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) da Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun.
+1. **[CTRL](model_doc/ctrl)** (da Salesforce) rilasciato con il paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) da Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong e Richard Socher.
+1. **[CvT](model_doc/cvt)** (da Microsoft) rilasciato con il paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) da Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
+1. **[Data2Vec](model_doc/data2vec)** (da Facebook) rilasciato con il paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) da Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
+1. **[DeBERTa](model_doc/deberta)** (da Microsoft) rilasciato con il paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) da Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
+1. **[DeBERTa-v2](model_doc/deberta-v2)** (da Microsoft) rilasciato con il paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) da Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
+1. **[Decision Transformer](model_doc/decision_transformer)** (da Berkeley/Facebook/Google) rilasciato con il paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) da Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
+1. **[DiT](model_doc/dit)** (da Microsoft Research) rilasciato con il paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) da Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
+1. **[DeiT](model_doc/deit)** (da Facebook) rilasciato con il paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) da Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, HervƩ JƩgou.
+1. **[DETR](model_doc/detr)** (da Facebook) rilasciato con il paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) da Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
+1. **[DialoGPT](model_doc/dialogpt)** (da Microsoft Research) rilasciato con il paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) da Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
+1. **[DistilBERT](model_doc/distilbert)** (da HuggingFace), rilasciato assieme al paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) da Victor Sanh, Lysandre Debut e Thomas Wolf. La stessa tecnica ĆØ stata applicata per comprimere GPT2 in [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa in [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT in [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
+1. **[DPR](model_doc/dpr)** (da Facebook) rilasciato con il paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) da Vladimir Karpukhin, Barlas OÄuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, e Wen-tau Yih.
+1. **[DPT](master/model_doc/dpt)** (da Intel Labs) rilasciato con il paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) da RenƩ Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
+1. **[EncoderDecoder](model_doc/encoder-decoder)** (da Google Research) rilasciato con il paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) da Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
+1. **[ELECTRA](model_doc/electra)** (da Google Research/Stanford University) rilasciato con il paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) da Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
+1. **[FlauBERT](model_doc/flaubert)** (da CNRS) rilasciato con il paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) da Hang Le, LoĆÆc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, BenoĆ®t CrabbĆ©, Laurent Besacier, Didier Schwab.
+1. **[FLAVA](model_doc/flava)** (da Facebook AI) rilasciato con il paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) da Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, e Douwe Kiela.
+1. **[FNet](model_doc/fnet)** (da Google Research) rilasciato con il paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) da James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.
+1. **[Funnel Transformer](model_doc/funnel)** (da CMU/Google Brain) rilasciato con il paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) da Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
+1. **[GLPN](model_doc/glpn)** (da KAIST) rilasciato con il paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) da Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
+1. **[GPT](model_doc/openai-gpt)** (da OpenAI) rilasciato con il paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) da Alec Radford, Karthik Narasimhan, Tim Salimans e Ilya Sutskever.
+1. **[GPT-2](model_doc/gpt2)** (da OpenAI) rilasciato con il paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) da Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** e Ilya Sutskever**.
+1. **[GPT-J](model_doc/gptj)** (da EleutherAI) rilasciato nel repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) da Ben Wang e Aran Komatsuzaki.
+1. **[GPT Neo](model_doc/gpt_neo)** (da EleutherAI) rilasciato nel repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) da Sid Black, Stella Biderman, Leo Gao, Phil Wang e Connor Leahy.
+1. **[GPT NeoX](model_doc/gpt_neox)** (da EleutherAI) rilasciato con il paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) da Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach
+1. **[Hubert](model_doc/hubert)** (da Facebook) rilasciato con il paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) da Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
+1. **[I-BERT](model_doc/ibert)** (da Berkeley) rilasciato con il paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) da Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer.
+1. **[ImageGPT](model_doc/imagegpt)** (da OpenAI) rilasciato con il paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) da Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
+1. **[LayoutLM](model_doc/layoutlm)** (da Microsoft Research Asia) rilasciato con il paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) da Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
+1. **[LayoutLMv2](model_doc/layoutlmv2)** (da Microsoft Research Asia) rilasciato con il paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) da Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou.
+1. **[LayoutLMv3](model_doc/layoutlmv3)** (da Microsoft Research Asia) rilasciato con il paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) da Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei.
+1. **[LayoutXLM](model_doc/layoutlmv2)** (da Microsoft Research Asia) rilasciato con il paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) da Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
+1. **[LED](model_doc/led)** (da AllenAI) rilasciato con il paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) da Iz Beltagy, Matthew E. Peters, Arman Cohan.
+1. **[Longformer](model_doc/longformer)** (da AllenAI) rilasciato con il paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) da Iz Beltagy, Matthew E. Peters, Arman Cohan.
+1. **[LUKE](model_doc/luke)** (da Studio Ousia) rilasciato con il paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) da Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
+1. **[mLUKE](model_doc/mluke)** (da Studio Ousia) rilasciato con il paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) da Ryokan Ri, Ikuya Yamada, e Yoshimasa Tsuruoka.
+1. **[LXMERT](model_doc/lxmert)** (da UNC Chapel Hill) rilasciato con il paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) da Hao Tan e Mohit Bansal.
+1. **[M2M100](model_doc/m2m_100)** (da Facebook) rilasciato con il paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) da Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
+1. **[MarianMT](model_doc/marian)** Modello di machine learning per le traduzioni allenato utilizzando i dati [OPUS](http://opus.nlpl.eu/) di Jƶrg Tiedemann. Il [Framework Marian](https://marian-nmt.github.io/) ĆØ stato sviluppato dal Microsoft Translator Team.
+1. **[MaskFormer](model_doc/maskformer)** (da Meta and UIUC) rilasciato con il paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) da Bowen Cheng, Alexander G. Schwing, Alexander Kirillov.
+1. **[MBart](model_doc/mbart)** (da Facebook) rilasciato con il paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) da Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
+1. **[MBart-50](model_doc/mbart)** (da Facebook) rilasciato con il paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) da Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
+1. **[Megatron-BERT](model_doc/megatron-bert)** (da NVIDIA) rilasciato con il paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) da Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper e Bryan Catanzaro.
+1. **[Megatron-GPT2](model_doc/megatron_gpt2)** (da NVIDIA) rilasciato con il paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) da Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper e Bryan Catanzaro.
+1. **[MPNet](model_doc/mpnet)** (da Microsoft Research) rilasciato con il paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) da Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
+1. **[MT5](model_doc/mt5)** (da Google AI) rilasciato con il paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) da Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
+1. **[Nystrƶmformer](model_doc/nystromformer)** (dalla UniversitĆ del Wisconsin - Madison) rilasciato con il paper [Nystrƶmformer: A Nystrƶm-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) da Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh.
+1. **[OPT](master/model_doc/opt)** (da Meta AI) rilasciato con il paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) da Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
+1. **[Pegasus](model_doc/pegasus)** (da Google) rilasciato con il paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) da Jingqing Zhang, Yao Zhao, Mohammad Saleh e Peter J. Liu.
+1. **[Perceiver IO](model_doc/perceiver)** (da Deepmind) rilasciato con il paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) da Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier HĆ©naff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, JoĆ£o Carreira.
+1. **[PhoBERT](model_doc/phobert)** (da VinAI Research) rilasciato con il paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) da Dat Quoc Nguyen e Anh Tuan Nguyen.
+1. **[PLBart](model_doc/plbart)** (da UCLA NLP) rilasciato con il paper [Unified Pre-training for Program Understanding and Generation](https://arxiv.org/abs/2103.06333) da Wasi Uddin Ahmad, Saikat Chakraborty, Baishakhi Ray, Kai-Wei Chang.
+1. **[PoolFormer](model_doc/poolformer)** (da Sea AI Labs) rilasciato con il paper [MetaFormer is Actually What You Need for Vision](https://arxiv.org/abs/2111.11418) da Yu, Weihao e Luo, Mi e Zhou, Pan e Si, Chenyang e Zhou, Yichen e Wang, Xinchao e Feng, Jiashi e Yan, Shuicheng.
+1. **[ProphetNet](model_doc/prophetnet)** (da Microsoft Research) rilasciato con il paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) da Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang e Ming Zhou.
+1. **[QDQBert](model_doc/qdqbert)** (da NVIDIA) rilasciato con il paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) da Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev e Paulius Micikevicius.
+1. **[REALM](model_doc/realm.html)** (da Google Research) rilasciato con il paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) da Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat e Ming-Wei Chang.
+1. **[Reformer](model_doc/reformer)** (da Google Research) rilasciato con il paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) da Nikita Kitaev, Åukasz Kaiser, Anselm Levskaya.
+1. **[RemBERT](model_doc/rembert)** (da Google Research) rilasciato con il paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) da Hyung Won Chung, Thibault FĆ©vry, Henry Tsai, M. Johnson, Sebastian Ruder.
+1. **[RegNet](model_doc/regnet)** (da META Platforms) rilasciato con il paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) da Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr DollƔr.
+1. **[ResNet](model_doc/resnet)** (da Microsoft Research) rilasciato con il paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) da Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
+1. **[RoBERTa](model_doc/roberta)** (da Facebook), rilasciato assieme al paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) da Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
+1. **[RoFormer](model_doc/roformer)** (da ZhuiyiTechnology), rilasciato assieme al paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) da Jianlin Su e Yu Lu e Shengfeng Pan e Bo Wen e Yunfeng Liu.
+1. **[SegFormer](model_doc/segformer)** (da NVIDIA) rilasciato con il paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) da Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
+1. **[SEW](model_doc/sew)** (da ASAPP) rilasciato con il paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) da Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
+1. **[SEW-D](model_doc/sew_d)** (da ASAPP) rilasciato con il paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) da Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
+1. **[SpeechToTextTransformer](model_doc/speech_to_text)** (da Facebook), rilasciato assieme al paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) da Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
+1. **[SpeechToTextTransformer2](model_doc/speech_to_text_2)** (da Facebook), rilasciato assieme al paper [Large-Scale Self- and Semi-Supervised Learning for Speech Translation](https://arxiv.org/abs/2104.06678) da Changhan Wang, Anne Wu, Juan Pino, Alexei Baevski, Michael Auli, Alexis Conneau.
+1. **[Splinter](model_doc/splinter)** (dalla UniversitĆ di Tel Aviv), rilasciato assieme al paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) da Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy.
+1. **[SqueezeBert](model_doc/squeezebert)** (da Berkeley) rilasciato con il paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) da Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, e Kurt W. Keutzer.
+1. **[Swin Transformer](model_doc/swin)** (da Microsoft) rilasciato con il paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) da Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
+1. **[T5](model_doc/t5)** (da Google AI) rilasciato con il paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) da Colin Raffel e Noam Shazeer e Adam Roberts e Katherine Lee e Sharan Narang e Michael Matena e Yanqi Zhou e Wei Li e Peter J. Liu.
+1. **[T5v1.1](model_doc/t5v1.1)** (da Google AI) rilasciato nel repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) da Colin Raffel e Noam Shazeer e Adam Roberts e Katherine Lee e Sharan Narang e Michael Matena e Yanqi Zhou e Wei Li e Peter J. Liu.
+1. **[TAPAS](model_doc/tapas)** (da Google AI) rilasciato con il paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) da Jonathan Herzig, PaweÅ Krzysztof Nowak, Thomas MĆ¼ller, Francesco Piccinno e Julian Martin Eisenschlos.
+1. **[TAPEX](model_doc/tapex)** (da Microsoft Research) rilasciato con il paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) da Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
+1. **[Trajectory Transformer](model_doc/trajectory_transformers)** (dall'UniversitĆ della California a Berkeley) rilasciato con il paper [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039) da Michael Janner, Qiyang Li, Sergey Levine
+1. **[Transformer-XL](model_doc/transfo-xl)** (da Google/CMU) rilasciato con il paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) da Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
+1. **[TrOCR](model_doc/trocr)** (da Microsoft), rilasciato assieme al paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) da Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
+1. **[UniSpeech](model_doc/unispeech)** (da Microsoft Research) rilasciato con il paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) da Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
+1. **[UniSpeechSat](model_doc/unispeech-sat)** (da Microsoft Research) rilasciato con il paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) da Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[VAN](model_doc/van)** (dalle UniversitĆ di Tsinghua e Nankai) rilasciato con il paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) da Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
+1. **[ViLT](model_doc/vilt)** (da NAVER AI Lab/Kakao Enterprise/Kakao Brain) rilasciato con il paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) da Wonjae Kim, Bokyung Son, Ildoo Kim.
+1. **[Vision Transformer (ViT)](model_doc/vit)** (da Google AI) rilasciato con il paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) da Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
+1. **[ViTMAE](model_doc/vit_mae)** (da Meta AI) rilasciato con il paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) da Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr DollƔr, Ross Girshick.
+1. **[VisualBERT](model_doc/visual_bert)** (da UCLA NLP) rilasciato con il paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) da Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
+1. **[WavLM](model_doc/wavlm)** (da Microsoft Research) rilasciato con il paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) da Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
+1. **[Wav2Vec2](model_doc/wav2vec2)** (da Facebook AI) rilasciato con il paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) da Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
+1. **[Wav2Vec2Phoneme](model_doc/wav2vec2_phoneme)** (da Facebook AI) rilasciato con il paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) da Qiantong Xu, Alexei Baevski, Michael Auli.
+1. **[XGLM](model_doc/xglm)** (da Facebook AI) rilasciato con il paper [Few-shot Learning with Multilingual Language Models](https://arxiv.org/abs/2112.10668) da Xi Victoria Lin, Todor Mihaylov, Mikel Artetxe, Tianlu Wang, Shuohui Chen, Daniel Simig, Myle Ott, Naman Goyal, Shruti Bhosale, Jingfei Du, Ramakanth Pasunuru, Sam Shleifer, Punit Singh Koura, Vishrav Chaudhary, Brian O'Horo, Jeff Wang, Luke Zettlemoyer, Zornitsa Kozareva, Mona Diab, Veselin Stoyanov, Xian Li.
+1. **[XLM](model_doc/xlm)** (v Facebook) rilasciato assieme al paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) da Guillaume Lample e Alexis Conneau.
+1. **[XLM-ProphetNet](model_doc/xlm-prophetnet)** (da Microsoft Research) rilasciato con il paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) da Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang e Ming Zhou.
+1. **[XLM-RoBERTa](model_doc/xlm-roberta)** (da Facebook AI), rilasciato assieme al paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) da Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco GuzmƔn, Edouard Grave, Myle Ott, Luke Zettlemoyer e Veselin Stoyanov.
+1. **[XLM-RoBERTa-XL](model_doc/xlm-roberta-xl)** (da Facebook AI), rilasciato assieme al paper [Larger-Scale Transformers for Multilingual Masked Language Modeling](https://arxiv.org/abs/2105.00572) da Naman Goyal, Jingfei Du, Myle Ott, Giri Anantharaman, Alexis Conneau.
+1. **[XLNet](model_doc/xlnet)** (da Google/CMU) rilasciato con il paper [āXLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) da Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
+1. **[XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2)** (da Facebook AI) rilasciato con il paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) da Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
+1. **[XLS-R](model_doc/xls_r)** (da Facebook AI) rilasciato con il paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) da Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
+1. **[YOLOS](model_doc/yolos)** (dalla UniversitĆ della scienza e tecnologia di Huazhong) rilasciato con il paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) da Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
+1. **[YOSO](model_doc/yoso)** (dall'UniversitĆ del Wisconsin - Madison) rilasciato con il paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) da Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
+
+
+### Framework supportati
+
+La tabella seguente rappresenta il supporto attuale nella libreria per ognuno di questi modelli, si puĆ² identificare se questi hanno un Python
+tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria š¤ Tokenizers, e se hanno supporto in Jax (via Flax), PyTorch, e/o TensorFlow.
+
+
+
+| Model | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
+|:---------------------------:|:--------------:|:--------------:|:---------------:|:------------------:|:------------:|
+| ALBERT | ā
| ā
| ā
| ā
| ā
|
+| BART | ā
| ā
| ā
| ā
| ā
|
+| BEiT | ā | ā | ā
| ā | ā
|
+| BERT | ā
| ā
| ā
| ā
| ā
|
+| Bert Generation | ā
| ā | ā
| ā | ā |
+| BigBird | ā
| ā
| ā
| ā | ā
|
+| BigBirdPegasus | ā | ā | ā
| ā | ā |
+| Blenderbot | ā
| ā
| ā
| ā
| ā
|
+| BlenderbotSmall | ā
| ā
| ā
| ā
| ā
|
+| CamemBERT | ā
| ā
| ā
| ā
| ā |
+| Canine | ā
| ā | ā
| ā | ā |
+| CLIP | ā
| ā
| ā
| ā
| ā
|
+| ConvBERT | ā
| ā
| ā
| ā
| ā |
+| ConvNext | ā | ā | ā
| ā
| ā |
+| CTRL | ā
| ā | ā
| ā
| ā |
+| CvT | ā | ā | ā
| ā | ā |
+| Data2VecAudio | ā | ā | ā
| ā | ā |
+| Data2VecText | ā | ā | ā
| ā | ā |
+| Data2VecVision | ā | ā | ā
| ā
| ā |
+| DeBERTa | ā
| ā
| ā
| ā
| ā |
+| DeBERTa-v2 | ā
| ā
| ā
| ā
| ā |
+| Decision Transformer | ā | ā | ā
| ā | ā |
+| DeiT | ā | ā | ā
| ā | ā |
+| DETR | ā | ā | ā
| ā | ā |
+| DistilBERT | ā
| ā
| ā
| ā
| ā
|
+| DPR | ā
| ā
| ā
| ā
| ā |
+| DPT | ā | ā | ā
| ā | ā |
+| ELECTRA | ā
| ā
| ā
| ā
| ā
|
+| Encoder decoder | ā | ā | ā
| ā
| ā
|
+| FairSeq Machine-Translation | ā
| ā | ā
| ā | ā |
+| FlauBERT | ā
| ā | ā
| ā
| ā |
+| Flava | ā | ā | ā
| ā | ā |
+| FNet | ā
| ā
| ā
| ā | ā |
+| Funnel Transformer | ā
| ā
| ā
| ā
| ā |
+| GLPN | ā | ā | ā
| ā | ā |
+| GPT Neo | ā | ā | ā
| ā | ā
|
+| GPT NeoX | ā | ā
| ā
| ā | ā |
+| GPT-J | ā | ā | ā
| ā
| ā
|
+| Hubert | ā | ā | ā
| ā
| ā |
+| I-BERT | ā | ā | ā
| ā | ā |
+| ImageGPT | ā | ā | ā
| ā | ā |
+| LayoutLM | ā
| ā
| ā
| ā
| ā |
+| LayoutLMv2 | ā
| ā
| ā
| ā | ā |
+| LayoutLMv3 | ā
| ā
| ā
| ā | ā |
+| LED | ā
| ā
| ā
| ā
| ā |
+| Longformer | ā
| ā
| ā
| ā
| ā |
+| LUKE | ā
| ā | ā
| ā | ā |
+| LXMERT | ā
| ā
| ā
| ā
| ā |
+| M2M100 | ā
| ā | ā
| ā | ā |
+| Marian | ā
| ā | ā
| ā
| ā
|
+| MaskFormer | ā | ā | ā
| ā | ā |
+| mBART | ā
| ā
| ā
| ā
| ā
|
+| MegatronBert | ā | ā | ā
| ā | ā |
+| MobileBERT | ā
| ā
| ā
| ā
| ā |
+| MPNet | ā
| ā
| ā
| ā
| ā |
+| mT5 | ā
| ā
| ā
| ā
| ā
|
+| Nystromformer | ā | ā | ā
| ā | ā |
+| OpenAI GPT | ā
| ā
| ā
| ā
| ā |
+| OpenAI GPT-2 | ā
| ā
| ā
| ā
| ā
|
+| OPT | ā | ā | ā
| ā | ā |
+| Pegasus | ā
| ā
| ā
| ā
| ā
|
+| Perceiver | ā
| ā | ā
| ā | ā |
+| PLBart | ā
| ā | ā
| ā | ā |
+| PoolFormer | ā | ā | ā
| ā | ā |
+| ProphetNet | ā
| ā | ā
| ā | ā |
+| QDQBert | ā | ā | ā
| ā | ā |
+| RAG | ā
| ā | ā
| ā
| ā |
+| Realm | ā
| ā
| ā
| ā | ā |
+| Reformer | ā
| ā
| ā
| ā | ā |
+| RegNet | ā | ā | ā
| ā | ā |
+| RemBERT | ā
| ā
| ā
| ā
| ā |
+| ResNet | ā | ā | ā
| ā | ā |
+| RetriBERT | ā
| ā
| ā
| ā | ā |
+| RoBERTa | ā
| ā
| ā
| ā
| ā
|
+| RoFormer | ā
| ā
| ā
| ā
| ā
|
+| SegFormer | ā | ā | ā
| ā | ā |
+| SEW | ā | ā | ā
| ā | ā |
+| SEW-D | ā | ā | ā
| ā | ā |
+| Speech Encoder decoder | ā | ā | ā
| ā | ā
|
+| Speech2Text | ā
| ā | ā
| ā
| ā |
+| Speech2Text2 | ā
| ā | ā | ā | ā |
+| Splinter | ā
| ā
| ā
| ā | ā |
+| SqueezeBERT | ā
| ā
| ā
| ā | ā |
+| Swin | ā | ā | ā
| ā
| ā |
+| T5 | ā
| ā
| ā
| ā
| ā
|
+| TAPAS | ā
| ā | ā
| ā
| ā |
+| Trajectory Transformer | ā | ā | ā
| ā | ā |
+| Transformer-XL | ā
| ā | ā
| ā
| ā |
+| TrOCR | ā | ā | ā
| ā | ā |
+| UniSpeech | ā | ā | ā
| ā | ā |
+| UniSpeechSat | ā | ā | ā
| ā | ā |
+| VAN | ā | ā | ā
| ā | ā |
+| ViLT | ā | ā | ā
| ā | ā |
+| Vision Encoder decoder | ā | ā | ā
| ā
| ā
|
+| VisionTextDualEncoder | ā | ā | ā
| ā | ā
|
+| VisualBert | ā | ā | ā
| ā | ā |
+| ViT | ā | ā | ā
| ā
| ā
|
+| ViTMAE | ā | ā | ā
| ā
| ā |
+| Wav2Vec2 | ā
| ā | ā
| ā
| ā
|
+| Wav2Vec2-Conformer | ā | ā | ā
| ā | ā |
+| WavLM | ā | ā | ā
| ā | ā |
+| XGLM | ā
| ā
| ā
| ā | ā
|
+| XLM | ā
| ā | ā
| ā
| ā |
+| XLM-RoBERTa | ā
| ā
| ā
| ā
| ā
|
+| XLM-RoBERTa-XL | ā | ā | ā
| ā | ā |
+| XLMProphetNet | ā
| ā | ā
| ā | ā |
+| XLNet | ā
| ā
| ā
| ā
| ā |
+| YOLOS | ā | ā | ā
| ā | ā |
+| YOSO | ā | ā | ā
| ā | ā |
+
+
\ No newline at end of file
diff --git a/docs/source/it/installation.mdx b/docs/source/it/installation.mdx
new file mode 100644
index 00000000000000..1ff47c110cffad
--- /dev/null
+++ b/docs/source/it/installation.mdx
@@ -0,0 +1,235 @@
+
+
+# Installazione
+
+Installa š¤ Transformers per qualsiasi libreria di deep learning con cui stai lavorando, imposta la tua cache, e opzionalmente configura š¤ Transformers per l'esecuzione offline.
+
+š¤ Transformers ĆØ testato su Python 3.6+, PyTorch 1.1.0+, TensorFlow 2.0+, e Flax. Segui le istruzioni di installazione seguenti per la libreria di deep learning che stai utilizzando:
+
+* [PyTorch](https://pytorch.org/get-started/locally/) istruzioni di installazione.
+* [TensorFlow 2.0](https://www.tensorflow.org/install/pip) istruzioni di installazione.
+* [Flax](https://flax.readthedocs.io/en/latest/) istruzioni di installazione.
+
+## Installazione con pip
+
+Puoi installare š¤ Transformers in un [ambiente virtuale](https://docs.python.org/3/library/venv.html). Se non sei familiare con gli ambienti virtuali in Python, dai un'occhiata a questa [guida](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). Un ambiente virtuale rende piĆ¹ semplice la gestione di progetti differenti, evitando problemi di compatibilitĆ tra dipendenze.
+
+Inizia creando un ambiente virtuale nella directory del tuo progetto:
+
+```bash
+python -m venv .env
+```
+
+Attiva l'ambiente virtuale:
+
+```bash
+source .env/bin/activate
+```
+
+Ora puoi procedere con l'installazione di š¤ Transformers eseguendo il comando seguente:
+
+```bash
+pip install transformers
+```
+
+Per il solo supporto della CPU, puoi installare facilmente š¤ Transformers e una libreria di deep learning in solo una riga. Ad esempio, installiamo š¤ Transformers e PyTorch con:
+
+```bash
+pip install transformers[torch]
+```
+
+š¤ Transformers e TensorFlow 2.0:
+
+```bash
+pip install transformers[tf-cpu]
+```
+
+š¤ Transformers e Flax:
+
+```bash
+pip install transformers[flax]
+```
+
+Infine, verifica se š¤ Transformers ĆØ stato installato in modo appropriato eseguendo il seguente comando. Questo scaricherĆ un modello pre-allenato:
+
+```bash
+python -c "from transformers import pipeline; print(pipeline('sentiment-analysis')('we love you'))"
+```
+
+DopodichƩ stampa l'etichetta e il punteggio:
+
+```bash
+[{'label': 'POSITIVE', 'score': 0.9998704791069031}]
+```
+
+## Installazione dalla fonte
+
+Installa š¤ Transformers dalla fonte con il seguente comando:
+
+```bash
+pip install git+https://github.com/huggingface/transformers
+```
+
+Questo comando installa la versione `main` piĆ¹ attuale invece dell'ultima versione stabile. Questo ĆØ utile per stare al passo con gli ultimi sviluppi. Ad esempio, se un bug ĆØ stato sistemato da quando ĆØ uscita l'ultima versione ufficiale ma non ĆØ stata ancora rilasciata una nuova versione. Tuttavia, questo significa che questa versione `main` puĆ² non essere sempre stabile. Ci sforziamo per mantenere la versione `main` operativa, e la maggior parte dei problemi viene risolta in poche ore o in un giorno. Se riscontri un problema, per favore apri una [Issue](https://github.com/huggingface/transformers/issues) cosƬ possiamo sistemarlo ancora piĆ¹ velocemente!
+
+Controlla se š¤ Transformers ĆØ stata installata in modo appropriato con il seguente comando:
+
+```bash
+python -c "from transformers import pipeline; print(pipeline('sentiment-analysis')('I love you'))"
+```
+
+## Installazione modificabile
+
+Hai bisogno di un'installazione modificabile se vuoi:
+
+* Usare la versione `main` del codice dalla fonte.
+* Contribuire a š¤ Transformers e hai bisogno di testare i cambiamenti nel codice.
+
+Clona il repository e installa š¤ Transformers con i seguenti comandi:
+
+```bash
+git clone https://github.com/huggingface/transformers.git
+cd transformers
+pip install -e .
+```
+
+Questi comandi collegheranno la cartella in cui ĆØ stato clonato il repository e i path delle librerie Python. Python guarderĆ ora all'interno della cartella clonata, oltre ai normali path delle librerie. Per esempio, se i tuoi pacchetti Python sono installati tipicamente in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python cercherĆ anche nella cartella clonata: `~/transformers/`.
+
+
+
+Devi tenere la cartella `transformers` se vuoi continuare ad utilizzare la libreria.
+
+
+
+Ora puoi facilmente aggiornare il tuo clone all'ultima versione di š¤ Transformers con il seguente comando:
+
+```bash
+cd ~/transformers/
+git pull
+```
+
+Il tuo ambiente Python troverĆ la versione `main` di š¤ Transformers alla prossima esecuzione.
+
+## Installazione con conda
+
+Installazione dal canale conda `huggingface`:
+
+```bash
+conda install -c huggingface transformers
+```
+
+## Impostazione della cache
+
+I modelli pre-allenati sono scaricati e memorizzati localmente nella cache in: `~/.cache/huggingface/transformers/`. Questa ĆØ la directory di default data dalla variabile d'ambiente della shell `TRANSFORMERS_CACHE`. Su Windows, la directory di default ĆØ data da `C:\Users\username\.cache\huggingface\transformers`. Puoi cambiare le variabili d'ambiente della shell indicate in seguito, in ordine di prioritĆ , per specificare una directory differente per la cache:
+
+1. Variabile d'ambiente della shell (default): `TRANSFORMERS_CACHE`.
+2. Variabile d'ambiente della shell: `HF_HOME` + `transformers/`.
+3. Variabile d'ambiente della shell: `XDG_CACHE_HOME` + `/huggingface/transformers`.
+
+
+
+š¤ Transformers utilizzerĆ le variabili d'ambiente della shell `PYTORCH_TRANSFORMERS_CACHE` o `PYTORCH_PRETRAINED_BERT_CACHE` se si proviene da un'iterazione precedente di questa libreria e sono state impostate queste variabili d'ambiente, a meno che non si specifichi la variabile d'ambiente della shell `TRANSFORMERS_CACHE`.
+
+
+
+## ModalitĆ Offline
+
+š¤ Transformers puĆ² essere eseguita in un ambiente firewalled o offline utilizzando solo file locali. Imposta la variabile d'ambiente `TRANSFORMERS_OFFLINE=1` per abilitare questo comportamento.
+
+
+
+Aggiungi [š¤ Datasets](https://huggingface.co/docs/datasets/) al tuo flusso di lavoro offline di training impostando la variabile d'ambiente `HF_DATASETS_OFFLINE=1`.
+
+
+
+Ad esempio, in genere si esegue un programma su una rete normale, protetta da firewall per le istanze esterne, con il seguente comando:
+
+```bash
+python examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
+```
+
+Esegui lo stesso programma in un'istanza offline con:
+
+```bash
+HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \
+python examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
+```
+
+Lo script viene ora eseguito senza bloccarsi o attendere il timeout, perchƩ sa di dover cercare solo file locali.
+
+### Ottenere modelli e tokenizer per l'uso offline
+
+Un'altra opzione per utilizzare offline š¤ Transformers ĆØ scaricare i file in anticipo, e poi puntare al loro path locale quando hai la necessitĆ di utilizzarli offline. Ci sono tre modi per fare questo:
+
+* Scarica un file tramite l'interfaccia utente sul [Model Hub](https://huggingface.co/models) premendo sull'icona ā.
+
+ ![download-icon](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/download-icon.png)
+
+* Utilizza il flusso [`PreTrainedModel.from_pretrained`] e [`PreTrainedModel.save_pretrained`]:
+
+ 1. Scarica i tuoi file in anticipo con [`PreTrainedModel.from_pretrained`]:
+
+ ```py
+ >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/T0_3B")
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0_3B")
+ ```
+
+ 2. Salva i tuoi file in una directory specificata con [`PreTrainedModel.save_pretrained`]:
+
+ ```py
+ >>> tokenizer.save_pretrained("./il/tuo/path/bigscience_t0")
+ >>> model.save_pretrained("./il/tuo/path/bigscience_t0")
+ ```
+
+ 3. Ora quando sei offline, carica i tuoi file con [`PreTrainedModel.from_pretrained`] dalla directory specificata:
+
+ ```py
+ >>> tokenizer = AutoTokenizer.from_pretrained("./il/tuo/path/bigscience_t0")
+ >>> model = AutoModel.from_pretrained("./il/tuo/path/bigscience_t0")
+ ```
+
+* Scarica in maniera programmatica i file con la libreria [huggingface_hub](https://github.com/huggingface/huggingface_hub/tree/main/src/huggingface_hub):
+
+ 1. Installa la libreria `huggingface_hub` nel tuo ambiente virtuale:
+
+ ```bash
+ python -m pip install huggingface_hub
+ ```
+
+ 2. Utilizza la funzione [`hf_hub_download`](https://huggingface.co/docs/hub/adding-a-library#download-files-from-the-hub) per scaricare un file in un path specifico. Per esempio, il seguente comando scarica il file `config.json` dal modello [T0](https://huggingface.co/bigscience/T0_3B) nel path che desideri:
+
+ ```py
+ >>> from huggingface_hub import hf_hub_download
+
+ >>> hf_hub_download(repo_id="bigscience/T0_3B", filename="config.json", cache_dir="./il/tuo/path/bigscience_t0")
+ ```
+
+Una volta che il tuo file ĆØ scaricato e salvato in cache localmente, specifica il suo path locale per caricarlo e utilizzarlo:
+
+```py
+>>> from transformers import AutoConfig
+
+>>> config = AutoConfig.from_pretrained("./il/tuo/path/bigscience_t0/config.json")
+```
+
+
+
+Fai riferimento alla sezione [How to download files from the Hub](https://huggingface.co/docs/hub/how-to-downstream) per avere maggiori dettagli su come scaricare modelli presenti sull Hub.
+
+
\ No newline at end of file
diff --git a/docs/source/it/pipeline_tutorial.mdx b/docs/source/it/pipeline_tutorial.mdx
new file mode 100644
index 00000000000000..2fdd0f8158c895
--- /dev/null
+++ b/docs/source/it/pipeline_tutorial.mdx
@@ -0,0 +1,148 @@
+
+
+# Pipeline per l'inferenza
+
+La [`pipeline`] rende semplice usare qualsiasi modello dal [Model Hub](https://huggingface.co/models) per fare inferenza su diversi compiti come generazione del testo, segmentazione di immagini e classificazione di audio. Anche se non hai esperienza con una modalitĆ specifica o non comprendi bene il codice che alimenta i modelli, ĆØ comunque possibile utilizzarli con l'opzione [`pipeline`]! Questa esercitazione ti insegnerĆ a:
+
+* Usare una [`pipeline`] per fare inferenza.
+* Usare uno specifico tokenizer o modello.
+* Usare una [`pipeline`] per compiti che riguardano audio e video.
+
+
+
+Dai un'occhiata alla documentazione di [`pipeline`] per una lista completa dei compiti supportati.
+
+
+
+## Utilizzo della Pipeline
+
+Nonostante ogni compito abbia una [`pipeline`] associata, ĆØ piĆ¹ semplice utilizzare l'astrazione generica della [`pipeline`] che contiene tutte quelle specifiche per ogni mansione. La [`pipeline`] carica automaticamente un modello predefinito e un tokenizer in grado di fare inferenza per il tuo compito.
+
+1. Inizia creando una [`pipeline`] e specificando il compito su cui fare inferenza:
+
+```py
+>>> from transformers import pipeline
+
+>>> generator = pipeline(task="text-generation")
+```
+
+2. Inserisci il testo in input nella [`pipeline`]:
+
+```py
+>>> generator(
+... "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone"
+... ) # doctest: +SKIP
+[{'generated_text': 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Seven for the Iron-priests at the door to the east, and thirteen for the Lord Kings at the end of the mountain'}]
+```
+
+Se hai piĆ¹ di un input, inseriscilo in una lista:
+
+```py
+>>> generator(
+... [
+... "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone",
+... "Nine for Mortal Men, doomed to die, One for the Dark Lord on his dark throne",
+... ]
+... ) # doctest: +SKIP
+```
+
+Qualsiasi parametro addizionale per il tuo compito puĆ² essere incluso nella [`pipeline`]. La mansione `text-generation` ha un metodo [`~generation_utils.GenerationMixin.generate`] con diversi parametri per controllare l'output. Ad esempio, se desideri generare piĆ¹ di un output, utilizza il parametro `num_return_sequences`:
+
+```py
+>>> generator(
+... "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone",
+... num_return_sequences=2,
+... ) # doctest: +SKIP
+```
+
+### Scegliere modello e tokenizer
+
+La [`pipeline`] accetta qualsiasi modello dal [Model Hub](https://huggingface.co/models). Ci sono tag nel Model Hub che consentono di filtrare i modelli per attivitĆ . Una volta che avrai scelto il modello appropriato, caricalo usando la corrispondente classe `AutoModelFor` e [`AutoTokenizer`]. Ad esempio, carica la classe [`AutoModelForCausalLM`] per un compito di causal language modeling:
+
+```py
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+```
+
+Crea una [`pipeline`] per il tuo compito, specificando il modello e il tokenizer che hai caricato:
+
+```py
+>>> from transformers import pipeline
+
+>>> generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
+```
+
+Inserisci il testo di input nella [`pipeline`] per generare del testo:
+
+```py
+>>> generator(
+... "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone"
+... ) # doctest: +SKIP
+[{'generated_text': 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Seven for the Dragon-lords (for them to rule in a world ruled by their rulers, and all who live within the realm'}]
+```
+
+## Audio pipeline
+
+La flessibilitĆ della [`pipeline`] fa si che possa essere estesa ad attivitĆ sugli audio.
+
+Per esempio, classifichiamo le emozioni in questo clip audio:
+
+```py
+>>> from datasets import load_dataset
+>>> import torch
+
+>>> torch.manual_seed(42) # doctest: +IGNORE_RESULT
+>>> ds = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+>>> audio_file = ds[0]["audio"]["path"]
+```
+
+Trova un modello per la [classificazione audio](https://huggingface.co/models?pipeline_tag=audio-classification) sul Model Hub per eseguire un compito di riconoscimento automatico delle emozioni e caricalo nella [`pipeline`]:
+
+```py
+>>> from transformers import pipeline
+
+>>> audio_classifier = pipeline(
+... task="audio-classification", model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
+... )
+```
+
+Inserisci il file audio nella [`pipeline`]:
+
+```py
+>>> preds = audio_classifier(audio_file)
+>>> preds = [{"score": round(pred["score"], 4), "label": pred["label"]} for pred in preds]
+>>> preds
+[{'score': 0.1315, 'label': 'calm'}, {'score': 0.1307, 'label': 'neutral'}, {'score': 0.1274, 'label': 'sad'}, {'score': 0.1261, 'label': 'fearful'}, {'score': 0.1242, 'label': 'happy'}]
+```
+
+## Vision pipeline
+
+Infine, usare la [`pipeline`] per le attivitĆ sulle immagini ĆØ praticamente la stessa cosa.
+
+Specifica la tua attivitĆ e inserisci l'immagine nel classificatore. L'immagine puĆ² essere sia un link che un percorso sul tuo pc in locale. Per esempio, quale specie di gatto ĆØ raffigurata qui sotto?
+
+![pipeline-cat-chonk](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg)
+
+```py
+>>> from transformers import pipeline
+
+>>> vision_classifier = pipeline(task="image-classification")
+>>> preds = vision_classifier(
+... images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+... )
+>>> preds = [{"score": round(pred["score"], 4), "label": pred["label"]} for pred in preds]
+>>> preds
+[{'score': 0.4335, 'label': 'lynx, catamount'}, {'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}, {'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}, {'score': 0.0239, 'label': 'Egyptian cat'}, {'score': 0.0229, 'label': 'tiger cat'}]
+```
diff --git a/docs/source/it/quicktour.mdx b/docs/source/it/quicktour.mdx
new file mode 100644
index 00000000000000..4caecac1177030
--- /dev/null
+++ b/docs/source/it/quicktour.mdx
@@ -0,0 +1,393 @@
+
+
+# Quick tour
+
+[[open-in-colab]]
+
+Entra in azione con š¤ Transformers! Inizia utilizzando [`pipeline`] per un'inferenza veloce, carica un modello pre-allenato e un tokenizer con una [AutoClass](./model_doc/auto) per risolvere i tuoi compiti legati a testo, immagini o audio.
+
+
+
+Tutti gli esempi di codice presenti in questa documentazione hanno un pulsante in alto a sinistra che permette di selezionare tra PyTorch e TensorFlow. Se
+questo non ĆØ presente, ci si aspetta che il codice funzioni per entrambi i backend senza alcun cambiamento.
+
+
+
+## Pipeline
+
+[`pipeline`] ĆØ il modo piĆ¹ semplice per utilizzare un modello pre-allenato per un dato compito.
+
+
+
+La [`pipeline`] supporta molti compiti comuni:
+
+**Testo**:
+* Analisi del Sentimento (Sentiment Analysis, in inglese): classifica la polaritĆ di un testo dato.
+* Generazione del Testo (Text Generation, in inglese): genera del testo a partire da un dato input.
+* Riconoscimento di EntitĆ (Name Entity Recognition o NER, in inglese): etichetta ogni parola con l'entitĆ che questa rappresenta (persona, data, luogo, ecc.).
+* Rispondere a Domande (Question answering, in inglese): estrae la risposta da un contesto, dato del contesto e una domanda.
+* Riempimento di Maschere (Fill-mask, in inglese): riempie gli spazi mancanti in un testo che ha parole mascherate.
+* Riassumere (Summarization, in inglese): genera una sintesi di una lunga sequenza di testo o di un documento.
+* Traduzione (Translation, in inglese): traduce un testo in un'altra lingua.
+* Estrazione di Caratteristiche (Feature Extraction, in inglese): crea un tensore che rappresenta un testo.
+
+**Immagini**:
+* Classificazione di Immagini (Image Classification, in inglese): classifica un'immagine.
+* Segmentazione di Immagini (Image Segmentation, in inglese): classifica ogni pixel di un'immagine.
+* Rilevazione di Oggetti (Object Detection, in inglese): rileva oggetti all'interno di un'immagine.
+
+**Audio**:
+* Classificazione di Audio (Audio Classification, in inglese): assegna un'etichetta ad un segmento di audio dato.
+* Riconoscimento Vocale Automatico (Automatic Speech Recognition o ASR, in inglese): trascrive il contenuto di un audio dato in un testo.
+
+
+
+Per maggiori dettagli legati alla [`pipeline`] e ai compiti ad essa associati, fai riferimento alla documentazione [qui](./main_classes/pipelines).
+
+
+
+### Utilizzo della Pipeline
+
+Nel seguente esempio, utilizzerai la [`pipeline`] per l'analisi del sentimento.
+
+Installa le seguenti dipendenze se non lo hai giĆ fatto:
+
+
+
+```bash
+pip install torch
+```
+
+
+```bash
+pip install tensorflow
+```
+
+
+
+Importa [`pipeline`] e specifica il compito che vuoi completare:
+
+```py
+>>> from transformers import pipeline
+
+>>> classificatore = pipeline("sentiment-analysis", model="MilaNLProc/feel-it-italian-sentiment")
+```
+
+La pipeline scarica e salva il [modello pre-allenato](https://huggingface.co/MilaNLProc/feel-it-italian-sentiment) e il tokenizer per l'analisi del sentimento. Se non avessimo scelto un modello, la pipeline ne avrebbe scelto uno di default. Ora puoi utilizzare il `classifier` sul tuo testo obiettivo:
+
+```py
+>>> classificatore("Siamo molto felici di mostrarti la libreria š¤ Transformers.")
+[{'label': 'positive', 'score': 0.9997}]
+```
+
+Per piĆ¹ di una frase, passa una lista di frasi alla [`pipeline`] la quale restituirĆ una lista di dizionari:
+
+```py
+>>> risultati = classificatore(
+... ["Siamo molto felici di mostrarti la libreria š¤ Transformers.", "Speriamo te non la odierai."]
+... )
+>>> for risultato in risultati:
+... print(f"etichetta: {risultato['label']}, con punteggio: {round(risultato['score'], 4)}")
+etichetta: positive, con punteggio: 0.9998
+etichetta: negative, con punteggio: 0.9998
+```
+
+La [`pipeline`] puĆ² anche iterare su un dataset intero. Inizia installando la libreria [š¤ Datasets](https://huggingface.co/docs/datasets/):
+
+```bash
+pip install datasets
+```
+
+Crea una [`pipeline`] con il compito che vuoi risolvere e con il modello che vuoi utilizzare.
+
+```py
+>>> import torch
+>>> from transformers import pipeline
+
+>>> riconoscitore_vocale = pipeline(
+... "automatic-speech-recognition", model="radiogroup-crits/wav2vec2-xls-r-1b-italian-doc4lm-5gram"
+... )
+```
+
+Poi, carica un dataset (vedi š¤ Datasets [Quick Start](https://huggingface.co/docs/datasets/quickstart.html) per maggiori dettagli) sul quale vuoi iterare. Per esempio, carichiamo il dataset [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14):
+
+```py
+>>> from datasets import load_dataset, Audio
+
+>>> dataset = load_dataset("PolyAI/minds14", name="it-IT", split="train") # doctest: +IGNORE_RESULT
+```
+
+Dobbiamo assicurarci che la frequenza di campionamento del set di dati corrisponda alla frequenza di campionamento con cui ĆØ stato addestrato `radiogroup-crits/wav2vec2-xls-r-1b-italian-doc4lm-5gram`.
+
+```py
+>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=riconoscitore_vocale.feature_extractor.sampling_rate))
+```
+
+I file audio vengono caricati automaticamente e ri-campionati quando chiamiamo la colonna "audio".
+Estraiamo i vettori delle forme d'onda grezze delle prime 4 osservazioni e passiamoli come lista alla pipeline:
+
+```py
+>>> risultato = riconoscitore_vocale(dataset[:4]["audio"])
+>>> print([d["text"] for d in risultato])
+['dovrei caricare dei soldi sul mio conto corrente', 'buongiorno e senza vorrei depositare denaro sul mio conto corrente come devo fare per cortesia', 'sƬ salve vorrei depositare del denaro sul mio conto', 'e buon pomeriggio vorrei depositare dei soldi sul mio conto bancario volleo sapere come posso fare se e posso farlo online ed un altro conto o andandoo tramite bancomut']
+```
+
+Per un dataset piĆ¹ grande dove gli input sono di dimensione maggiore (come nel parlato/audio o nella visione), dovrai passare un generatore al posto di una lista che carica tutti gli input in memoria. Guarda la [documentazione della pipeline](./main_classes/pipelines) per maggiori informazioni.
+
+### Utilizzare un altro modello e tokenizer nella pipeline
+
+La [`pipeline`] puĆ² ospitare qualsiasi modello del [Model Hub](https://huggingface.co/models), rendendo semplice l'adattamento della [`pipeline`] per altri casi d'uso. Per esempio, se si vuole un modello capace di trattare testo in francese, usa i tag presenti nel Model Hub in modo da filtrare per ottenere un modello appropriato. Il miglior risultato filtrato restituisce un modello multi-lingua [BERT model](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment) fine-tuned per l'analisi del sentimento. Ottimo, utilizziamo questo modello!
+
+```py
+>>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
+```
+
+
+
+Usa [`AutoModelForSequenceClassification`] e [`AutoTokenizer`] per caricare il modello pre-allenato e il suo tokenizer associato (maggiori informazioni su una `AutoClass` in seguito):
+
+```py
+>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+>>> model = AutoModelForSequenceClassification.from_pretrained(model_name)
+>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
+```
+
+
+Usa [`TFAutoModelForSequenceClassification`] e [`AutoTokenizer`] per caricare il modello pre-allenato e il suo tokenizer associato (maggiori informazioni su una `TFAutoClass` in seguito):
+
+```py
+>>> from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
+
+>>> model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
+>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
+```
+
+
+
+Poi puoi specificare il modello e il tokenizer nella [`pipeline`], e applicare il `classifier` sul tuo testo obiettivo:
+
+```py
+>>> classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
+>>> classifier("Nous sommes trĆØs heureux de vous prĆ©senter la bibliothĆØque š¤ Transformers.")
+[{'label': '5 stars', 'score': 0.7273}]
+```
+
+Se non riesci a trovare un modello per il tuo caso d'uso, dovrai fare fine-tuning di un modello pre-allenato sui tuoi dati. Dai un'occhiata al nostro tutorial [fine-tuning tutorial](./training) per imparare come. Infine, dopo che hai completato il fine-tuning del tuo modello pre-allenato, considera per favore di condividerlo (vedi il tutorial [qui](./model_sharing)) con la comunitĆ sul Model Hub per democratizzare l'NLP! š¤
+
+## AutoClass
+
+
+
+Al suo interno, le classi [`AutoModelForSequenceClassification`] e [`AutoTokenizer`] lavorano assieme per dare potere alla [`pipeline`]. Una [AutoClass](./model_doc/auto) ĆØ una scorciatoia che automaticamente recupera l'architettura di un modello pre-allenato a partire dal suo nome o path. Hai solo bisogno di selezionare la `AutoClass` appropriata per il tuo compito e il suo tokenizer associato con [`AutoTokenizer`].
+
+Ritorniamo al nostro esempio e vediamo come puoi utilizzare la `AutoClass` per replicare i risultati della [`pipeline`].
+
+### AutoTokenizer
+
+Un tokenizer ĆØ responsabile dell'elaborazione del testo in modo da trasformarlo in un formato comprensibile dal modello. Per prima cosa, il tokenizer dividerĆ il testo in parole chiamate *token*. Ci sono diverse regole che governano il processo di tokenizzazione, tra cui come dividere una parola e a quale livello (impara di piĆ¹ sulla tokenizzazione [qui](./tokenizer_summary)). La cosa piĆ¹ importante da ricordare comunque ĆØ che hai bisogno di inizializzare il tokenizer con lo stesso nome del modello in modo da assicurarti che stai utilizzando le stesse regole di tokenizzazione con cui il modello ĆØ stato pre-allenato.
+
+Carica un tokenizer con [`AutoTokenizer`]:
+
+```py
+>>> from transformers import AutoTokenizer
+
+>>> nome_del_modello = "nlptown/bert-base-multilingual-uncased-sentiment"
+>>> tokenizer = AutoTokenizer.from_pretrained(nome_del_modello)
+```
+
+DopodichĆ©, il tokenizer converte i token in numeri in modo da costruire un tensore come input del modello. Questo ĆØ conosciuto come il *vocabolario* del modello.
+
+Passa il tuo testo al tokenizer:
+
+```py
+>>> encoding = tokenizer("Siamo molto felici di mostrarti la libreria š¤ Transformers.")
+>>> print(encoding)
+{'input_ids': [101, 56821, 10132, 14407, 13019, 13007, 10120, 47201, 10330, 10106, 91686, 100, 58263, 119, 102],
+'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
+```
+
+Il tokenizer restituirĆ un dizionario contenente:
+
+* [input_ids](./glossary#input-ids): rappresentazioni numeriche dei tuoi token.
+* [attention_mask](.glossary#attention-mask): indica quali token devono essere presi in considerazione.
+
+Come con la [`pipeline`], il tokenizer accetterĆ una lista di input. In piĆ¹, il tokenizer puĆ² anche completare (pad, in inglese) e troncare il testo in modo da restituire un lotto (batch, in inglese) di lunghezza uniforme:
+
+
+
+```py
+>>> pt_batch = tokenizer(
+... ["Siamo molto felici di mostrarti la libreria š¤ Transformers.", "Speriamo te non la odierai."],
+... padding=True,
+... truncation=True,
+... max_length=512,
+... return_tensors="pt",
+... )
+```
+
+
+```py
+>>> tf_batch = tokenizer(
+... ["Siamo molto felici di mostrarti la libreria š¤ Transformers.", "Speriamo te non la odierai."],
+... padding=True,
+... truncation=True,
+... max_length=512,
+... return_tensors="tf",
+... )
+```
+
+
+
+Leggi il tutorial sul [preproccesing](./preprocessing) per maggiori dettagli sulla tokenizzazione.
+
+### AutoModel
+
+
+
+š¤ Transformers fornisce un metodo semplice e unificato per caricare istanze pre-allenate. Questo significa che puoi caricare un [`AutoModel`] come caricheresti un [`AutoTokenizer`]. L'unica differenza ĆØ selezionare l'[`AutoModel`] corretto per il compito di interesse. Dato che stai facendo classificazione di testi, o sequenze, carica [`AutoModelForSequenceClassification`]:
+
+```py
+>>> from transformers import AutoModelForSequenceClassification
+
+>>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(model_name)
+```
+
+
+
+Guarda il [task summary](./task_summary) per sapere quale classe di [`AutoModel`] utilizzare per quale compito.
+
+
+
+Ora puoi passare il tuo lotto di input pre-processati direttamente al modello. Devi solo spacchettare il dizionario aggiungendo `**`:
+
+```py
+>>> pt_outputs = pt_model(**pt_batch)
+```
+
+Il modello produrrĆ le attivazioni finali nell'attributo `logits`. Applica la funzione softmax a `logits` per ottenere le probabilitĆ :
+
+```py
+>>> from torch import nn
+
+>>> pt_predictions = nn.functional.softmax(pt_outputs.logits, dim=-1)
+>>> print(pt_predictions)
+tensor([[0.0041, 0.0037, 0.0203, 0.2005, 0.7713],
+ [0.3766, 0.3292, 0.1832, 0.0558, 0.0552]], grad_fn=)
+```
+
+
+š¤ Transformers fornisce un metodo semplice e unificato per caricare istanze pre-allenate. Questo significa che puoi caricare un [`TFAutoModel`] come caricheresti un [`AutoTokenizer`]. L'unica differenza ĆØ selezionare il [`TFAutoModel`] corretto per il compito di interesse. Dato che stai facendo classificazione di testi, o sequenze, carica [`TFAutoModelForSequenceClassification`]:
+
+```py
+>>> from transformers import TFAutoModelForSequenceClassification
+
+>>> nome_del_modello = "nlptown/bert-base-multilingual-uncased-sentiment"
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(nome_del_modello)
+```
+
+
+
+Guarda il [task summary](./task_summary) per sapere quale classe di [`AutoModel`] utilizzare per quale compito.
+
+
+
+Ora puoi passare il tuo lotto di input pre-processati direttamente al modello passando le chiavi del dizionario al tensore:
+
+```py
+>>> tf_outputs = tf_model(tf_batch)
+```
+
+Il modello produrrĆ le attivazioni finali nell'attributo `logits`. Applica la funzione softmax a `logits` per ottenere le probabilitĆ :
+```py
+>>> import tensorflow as tf
+
+>>> tf_predictions = tf.nn.softmax(tf_outputs.logits, axis=-1)
+>>> tf_predictions # doctest: +IGNORE_RESULT
+```
+
+
+
+
+
+Tutti i modelli di š¤ Transformers (PyTorch e TensorFlow) restituiscono i tensori *prima* della funzione finale
+di attivazione (come la softmax) perchƩ la funzione di attivazione finale viene spesso unita a quella di perdita.
+
+
+
+I modelli sono [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) o [`tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model) standard cosƬ puoi utilizzarli all'interno del tuo training loop usuale. Tuttavia, per rendere le cose piĆ¹ semplici, š¤ Transformers fornisce una classe [`Trainer`] per PyTorch che aggiunge delle funzionalitĆ per l'allenamento distribuito, precisione mista, e altro ancora. Per TensorFlow, puoi utilizzare il metodo `fit` di [Keras](https://keras.io/). Fai riferimento al [tutorial per il training](./training) per maggiori dettagli.
+
+
+
+Gli output del modello di š¤ Transformers sono delle dataclasses speciali in modo che i loro attributi vengano auto-completati all'interno di un IDE.
+Gli output del modello si comportano anche come una tupla o un dizionario (ad esempio, puoi indicizzare con un intero, una slice o una stringa) nel qual caso gli attributi che sono `None` vengono ignorati.
+
+
+
+### Salva un modello
+
+
+
+Una volta completato il fine-tuning del tuo modello, puoi salvarlo con il suo tokenizer utilizzando [`PreTrainedModel.save_pretrained`]:
+
+```py
+>>> pt_save_directory = "./pt_save_pretrained"
+>>> tokenizer.save_pretrained(pt_save_directory) # doctest: +IGNORE_RESULT
+>>> pt_model.save_pretrained(pt_save_directory)
+```
+
+Quando desideri utilizzare il tuo modello nuovamente, puoi ri-caricarlo con [`PreTrainedModel.from_pretrained`]:
+
+```py
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained("./pt_save_pretrained")
+```
+
+
+Una volta completato il fine-tuning del tuo modello, puoi salvarlo con il suo tokenizer utilizzando [`TFPreTrainedModel.save_pretrained`]:
+
+```py
+>>> tf_save_directory = "./tf_save_pretrained"
+>>> tokenizer.save_pretrained(tf_save_directory) # doctest: +IGNORE_RESULT
+>>> tf_model.save_pretrained(tf_save_directory)
+```
+
+Quando desideri utilizzare il tuo modello nuovamente, puoi ri-caricarlo con [`TFPreTrainedModel.from_pretrained`]:
+
+```py
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained("./tf_save_pretrained")
+```
+
+
+
+Una caratteristica particolarmente interessante di š¤ Transformers ĆØ la sua abilitĆ di salvare un modello e ri-caricarlo sia come modello di PyTorch che di TensorFlow. I parametri `from_pt` o `from_tf` possono convertire un modello da un framework all'altro:
+
+
+
+```py
+>>> from transformers import AutoModel
+
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+```
+
+
+```py
+>>> from transformers import TFAutoModel
+
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+```
+
+
diff --git a/docs/source/pt/_config.py b/docs/source/pt/_config.py
new file mode 100644
index 00000000000000..cd76263e9a5cb2
--- /dev/null
+++ b/docs/source/pt/_config.py
@@ -0,0 +1,14 @@
+# docstyle-ignore
+INSTALL_CONTENT = """
+# Transformers installation
+! pip install transformers datasets
+# To install from source instead of the last release, comment the command above and uncomment the following one.
+# ! pip install git+https://github.com/huggingface/transformers.git
+"""
+
+notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
+black_avoid_patterns = {
+ "{processor_class}": "FakeProcessorClass",
+ "{model_class}": "FakeModelClass",
+ "{object_class}": "FakeObjectClass",
+}
diff --git a/docs/source/pt/_toctree.yml b/docs/source/pt/_toctree.yml
new file mode 100644
index 00000000000000..a144ff00930ffd
--- /dev/null
+++ b/docs/source/pt/_toctree.yml
@@ -0,0 +1,26 @@
+- sections:
+ - local: quicktour
+ title: Tour rƔpido
+ - local: installation
+ title: InstalaĆ§Ć£o
+ title: Iniciar
+- sections:
+ - local: pipeline_tutorial
+ title: Pipelines para inferĆŖncia
+ - local: training
+ title: Fine-tuning de um modelo prƩ-treinado
+ - local: accelerate
+ title: Treinamento distribuĆdo com š¤ Accelerate
+ title: Tutoriais
+- sections:
+ - local: fast_tokenizers
+ title: Usando os Tokenizers do š¤ Tokenizers
+ - sections:
+ - local: tasks/sequence_classification
+ title: ClassificaĆ§Ć£o de texto
+ - local: tasks/token_classification
+ title: ClassificaĆ§Ć£o de tokens
+ title: Fine-tuning para tarefas especĆficas
+ - local: multilingual
+ title: Modelos multilinguĆsticos para inferĆŖncia
+ title: Guias prƔticos
diff --git a/docs/source/pt/accelerate.mdx b/docs/source/pt/accelerate.mdx
new file mode 100644
index 00000000000000..0e2257faceff84
--- /dev/null
+++ b/docs/source/pt/accelerate.mdx
@@ -0,0 +1,141 @@
+
+
+# Treinamento distribuĆdo com o š¤ Accelerate
+
+O paralelismo surgiu como uma estratƩgia para treinar modelos grandes em hardware limitado e aumentar a velocidade
+de treinamento em vĆ”rias Ć³rdens de magnitude. Na Hugging Face criamos a biblioteca [š¤ Accelerate](https://huggingface.co/docs/accelerate/index.html)
+para ajudar os usuĆ”rios a treinar modelos š¤ Transformers com qualquer configuraĆ§Ć£o distribuĆda, seja em uma mĆ”quina
+com mĆŗltiplos GPUs ou em mĆŗltiplos GPUs distribuidos entre muitas mĆ”quinas. Neste tutorial, vocĆŖ irĆ” aprender como
+personalizar seu laƧo de treinamento de PyTorch para poder treinar em ambientes distribuĆdos.
+
+## ConfiguraĆ§Ć£o
+
+De inĆcio, instale o š¤ Accelerate:
+
+```bash
+pip install accelerate
+```
+
+Logo, devemos importar e criar um objeto [`Accelerator`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator).
+O `Accelerator` detectarĆ” automĆ”ticamente a configuraĆ§Ć£o distribuĆda disponĆvel e inicializarĆ” todos os
+componentes necessĆ”rios para o treinamento. NĆ£o hĆ” necessidade portanto de especificar o dispositivo onde deve colocar seu modelo.
+
+```py
+>>> from accelerate import Accelerator
+
+>>> accelerator = Accelerator()
+```
+
+## Preparando a aceleraĆ§Ć£o
+
+Passe todos os objetos relevantes ao treinamento para o mƩtodo [`prepare`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator.prepare).
+Isto inclui os DataLoaders de treino e evaluaĆ§Ć£o, um modelo e um otimizador:
+
+```py
+>>> train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(
+... train_dataloader, eval_dataloader, model, optimizer
+... )
+```
+
+## Backward
+
+Por Ćŗltimo, substitua o `loss.backward()` padrĆ£o em seu laƧo de treinamento com o mĆ©todo [`backward`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator.backward) do š¤ Accelerate:
+
+```py
+>>> for epoch in range(num_epochs):
+... for batch in train_dataloader:
+... outputs = model(**batch)
+... loss = outputs.loss
+... accelerator.backward(loss)
+
+... optimizer.step()
+... lr_scheduler.step()
+... optimizer.zero_grad()
+... progress_bar.update(1)
+```
+
+Como se poder ver no seguinte cĆ³digo, sĆ³ precisarĆ” adicionar quatro linhas de cĆ³digo ao seu laƧo de treinamento
+para habilitar o treinamento distribuĆdo!
+
+```diff
++ from accelerate import Accelerator
+ from transformers import AdamW, AutoModelForSequenceClassification, get_scheduler
+
++ accelerator = Accelerator()
+
+ model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
+ optimizer = AdamW(model.parameters(), lr=3e-5)
+
+- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+- model.to(device)
+
++ train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(
++ train_dataloader, eval_dataloader, model, optimizer
++ )
+
+ num_epochs = 3
+ num_training_steps = num_epochs * len(train_dataloader)
+ lr_scheduler = get_scheduler(
+ "linear",
+ optimizer=optimizer,
+ num_warmup_steps=0,
+ num_training_steps=num_training_steps
+ )
+
+ progress_bar = tqdm(range(num_training_steps))
+
+ model.train()
+ for epoch in range(num_epochs):
+ for batch in train_dataloader:
+- batch = {k: v.to(device) for k, v in batch.items()}
+ outputs = model(**batch)
+ loss = outputs.loss
+- loss.backward()
++ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ progress_bar.update(1)
+```
+
+## Treinamento
+
+Quando tiver adicionado as linhas de cĆ³digo relevantes, inicie o treinamento por um script ou notebook como o Colab.
+
+### Treinamento em um Script
+
+Se estiver rodando seu treinamento em um Script, execute o seguinte comando para criar e guardar um arquivo de configuraĆ§Ć£o:
+
+```bash
+accelerate config
+```
+
+Comece o treinamento com:
+
+```bash
+accelerate launch train.py
+```
+
+### Treinamento em um Notebook
+
+O š¤ Accelerate pode rodar em um notebook, por exemplo, se estiver planejando usar as TPUs do Google Colab.
+Encapsule o cĆ³digo responsĆ”vel pelo treinamento de uma funĆ§Ć£o e passe-o ao `notebook_launcher`:
+
+```py
+>>> from accelerate import notebook_launcher
+
+>>> notebook_launcher(training_function)
+```
+
+Para obter mais informaƧƵes sobre o š¤ Accelerate e suas numerosas funƧƵes, consulte a [documentaciĆ³n](https://huggingface.co/docs/accelerate/index.html).
diff --git a/docs/source/pt/fast_tokenizers.mdx b/docs/source/pt/fast_tokenizers.mdx
new file mode 100644
index 00000000000000..aff9afb31f2bb3
--- /dev/null
+++ b/docs/source/pt/fast_tokenizers.mdx
@@ -0,0 +1,62 @@
+
+
+# Usando os Tokenizers do š¤ Tokenizers
+
+O [`PreTrainedTokenizerFast`] depende da biblioteca [š¤ Tokenizers](https://huggingface.co/docs/tokenizers). O Tokenizer obtido da biblioteca š¤ Tokenizers pode ser carregado facilmente pelo š¤ Transformers.
+
+Antes de entrar nos detalhes, vamos comeƧar criando um tokenizer fictĆcio em algumas linhas:
+
+```python
+>>> from tokenizers import Tokenizer
+>>> from tokenizers.models import BPE
+>>> from tokenizers.trainers import BpeTrainer
+>>> from tokenizers.pre_tokenizers import Whitespace
+
+>>> tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+>>> trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
+
+>>> tokenizer.pre_tokenizer = Whitespace()
+>>> files = [...]
+>>> tokenizer.train(files, trainer)
+```
+
+Agora temos um tokenizer treinado nos arquivos que foram definidos. NĆ³s podemos continuar usando nessa execuĆ§Ć£o ou salvar em um arquivo JSON para re-utilizar no futuro.
+
+## Carregando diretamente de um objeto tokenizer
+
+Vamos ver como aproveitar esse objeto tokenizer na biblioteca š¤ Transformers. A classe [`PreTrainedTokenizerFast`] permite uma instanciaĆ§Ć£o fĆ”cil, aceitando o objeto *tokenizer* instanciado como um argumento:
+
+```python
+>>> from transformers import PreTrainedTokenizerFast
+
+>>> fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
+```
+Esse objeto pode ser utilizado com todos os mĆ©todos compartilhados pelos tokenizers dos š¤ Transformers! VĆ” para [a pĆ”gina do tokenizer](main_classes/tokenizer) para mais informaƧƵes.
+
+## Carregando de um arquivo JSON
+
+Para carregar um tokenizer de um arquivo JSON vamos primeiro comeƧar salvando nosso tokenizer:
+
+```python
+>>> tokenizer.save("tokenizer.json")
+```
+
+A pasta para qual salvamos esse arquivo pode ser passada para o mĆ©todo de inicializaĆ§Ć£o do [`PreTrainedTokenizerFast`] usando o `tokenizer_file` parĆ¢metro:
+
+```python
+>>> from transformers import PreTrainedTokenizerFast
+
+>>> fast_tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
+```
+
+Esse objeto pode ser utilizado com todos os mĆ©todos compartilhados pelos tokenizers dos š¤ Transformers! VĆ” para [a pĆ”gina do tokenizer](main_classes/tokenizer) para mais informaƧƵes.
\ No newline at end of file
diff --git a/docs/source/pt/installation.mdx b/docs/source/pt/installation.mdx
new file mode 100644
index 00000000000000..2325cc74afe2d9
--- /dev/null
+++ b/docs/source/pt/installation.mdx
@@ -0,0 +1,258 @@
+
+
+# Guia de InstalaĆ§Ć£o
+
+Neste guia poderĆ” encontrar informaƧƵes para a instalaĆ§Ć£o do š¤ Transformers para qualquer biblioteca de
+Machine Learning com a qual esteja a trabalhar. AlĆ©m disso, poderĆ” encontrar informaƧƵes sobre como gerar cachĆŖs e
+configurar o š¤ Transformers para execuĆ§Ć£o em modo offline (opcional).
+
+š¤ Transformers foi testado com Python 3.6+, PyTorch 1.1.0+, TensorFlow 2.0+, e Flax. Para instalar a biblioteca de
+deep learning com que deseja trabalhar, siga as instruƧƵes correspondentes listadas a seguir:
+
+* [PyTorch](https://pytorch.org/get-started/locally/)
+* [TensorFlow 2.0](https://www.tensorflow.org/install/pip)
+* [Flax](https://flax.readthedocs.io/en/latest/)
+
+## InstalaĆ§Ć£o pelo Pip
+
+Ć sugerido instalar o š¤ Transformers num [ambiente virtual](https://docs.python.org/3/library/venv.html). Se precisar
+de mais informaƧƵes sobre ambientes virtuais em Python, consulte este [guia](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
+Um ambiente virtual facilitarĆ” a manipulaĆ§Ć£o e organizaĆ§Ć£o de projetos e evita problemas de compatibilidade entre dependĆŖncias.
+
+Comece criando um ambiente virtual no diretĆ³rio do seu projeto:
+
+```bash
+python -m venv .env
+```
+
+E para ativar o ambiente virtual:
+
+```bash
+source .env/bin/activate
+```
+
+Agora Ć possĆvel instalar o š¤ Transformers com o comando a seguir:
+
+```bash
+pip install transformers
+```
+
+Somente para a CPU, Ć© possĆvel instalar o š¤ Transformers e a biblioteca de deep learning respectiva apenas numa linha.
+
+Por exemplo, para instalar o š¤ Transformers e o PyTorch, digite:
+
+```bash
+pip install transformers[torch]
+```
+
+š¤ Transformers e TensorFlow 2.0:
+
+```bash
+pip install transformers[tf-cpu]
+```
+
+š¤ Transformers e Flax:
+
+```bash
+pip install transformers[flax]
+```
+
+Por Ćŗltimo, verifique se o š¤ Transformers foi instalado com sucesso usando o seguinte comando para baixar um modelo prĆ©-treinado:
+
+```bash
+python -c "from transformers import pipeline; print(pipeline('sentiment-analysis')('we love you'))"
+```
+
+Em seguida, imprima um rĆ³tulo e sua pontuaĆ§Ć£o:
+
+```bash
+[{'label': 'POSITIVE', 'score': 0.9998704791069031}]
+```
+
+## InstalaĆ§Ć£o usando a fonte
+
+Para instalar o š¤ Transformers a partir da fonte use o seguinte comando:
+
+```bash
+pip install git+https://github.com/huggingface/transformers
+```
+
+O comando acima instalarĆ” a versĆ£o `master` mais atual em vez da Ćŗltima versĆ£o estĆ”vel. A versĆ£o `master` Ć© Ćŗtil para
+utilizar os Ćŗltimos updates contidos em š¤ Transformers. Por exemplo, um erro recente pode ter sido corrigido somente
+apĆ³s a Ćŗltima versĆ£o estĆ”vel, antes que houvesse um novo lanƧamento. No entanto, hĆ” a possibilidade que a versĆ£o `master` nĆ£o esteja estĆ”vel.
+A equipa trata de mantĆ©r a versĆ£o `master` operacional e a maioria dos erros sĆ£o resolvidos em poucas horas ou dias.
+Se encontrar quaisquer problemas, por favor abra um [Issue](https://github.com/huggingface/transformers/issues) para que o
+mesmo possa ser corrigido o mais rĆ”pido possĆvel.
+
+Verifique que o š¤ Transformers estĆ” instalado corretamente usando o seguinte comando:
+
+```bash
+python -c "from transformers import pipeline; print(pipeline('sentiment-analysis')('I love you'))"
+```
+
+## InstalaĆ§Ć£o editĆ”vel
+
+Uma instalaĆ§Ć£o editĆ”vel serĆ” necessĆ”ria caso desejas um dos seguintes:
+* Usar a versĆ£o `master` do cĆ³digo fonte.
+* Contribuir ao š¤ Transformers e precisa testar mudanƧas ao cĆ³digo.
+
+Para tal, clone o repositĆ³rio e instale o š¤ Transformers com os seguintes comandos:
+
+```bash
+git clone https://github.com/huggingface/transformers.git
+cd transformers
+pip install -e .
+```
+
+Estes comandos vĆ£o ligar o diretĆ³rio para o qual foi clonado o repositĆ³rio ao caminho de bibliotecas do Python.
+O Python agora buscarƔ dentro dos arquivos que foram clonados alƩm dos caminhos normais da biblioteca.
+Por exemplo, se os pacotes do Python se encontram instalados no caminho `~/anaconda3/envs/main/lib/python3.7/site-packages/`,
+o Python tambĆ©m buscarĆ” mĆ³dulos no diretĆ³rio onde clonamos o repositĆ³rio `~/transformers/`.
+
+
+
+Ć necessĆ”rio manter o diretĆ³rio `transformers` se desejas continuar usando a biblioteca.
+
+
+
+Assim, Ć possĆvel atualizar sua cĆ³pia local para com a Ćŗltima versĆ£o do š¤ Transformers com o seguinte comando:
+
+```bash
+cd ~/transformers/
+git pull
+```
+
+O ambiente de Python que foi criado para a instalaĆ§Ć£o do š¤ Transformers encontrarĆ” a versĆ£o `master` em execuƧƵes seguintes.
+
+## InstalaĆ§Ć£o usando o Conda
+
+Ć possĆvel instalar o š¤ Transformers a partir do canal conda `huggingface` com o seguinte comando:
+
+```bash
+conda install -c huggingface transformers
+```
+
+## ConfiguraĆ§Ć£o do CachĆŖ
+
+Os modelos prĆ©-treinados sĆ£o baixados e armazenados no cachĆŖ local, encontrado em `~/.cache/huggingface/transformers/`.
+Este Ć© o diretĆ³rio padrĆ£o determinado pela variĆ”vel `TRANSFORMERS_CACHE` dentro do shell.
+No Windows, este diretĆ³rio prĆ©-definido Ć© dado por `C:\Users\username\.cache\huggingface\transformers`.
+Ć possĆvel mudar as variĆ”veis dentro do shell em ordem de prioridade para especificar um diretĆ³rio de cachĆŖ diferente:
+
+1. VariĆ”vel de ambiente do shell (por padrĆ£o): `TRANSFORMERS_CACHE`.
+2. VariƔvel de ambiente do shell:`HF_HOME` + `transformers/`.
+3. VariƔvel de ambiente do shell: `XDG_CACHE_HOME` + `/huggingface/transformers`.
+
+
+
+ O š¤ Transformers usarĆ” as variĆ”veis de ambiente do shell `PYTORCH_TRANSFORMERS_CACHE` ou `PYTORCH_PRETRAINED_BERT_CACHE`
+ se estiver vindo de uma versĆ£o anterior da biblioteca que tenha configurado essas variĆ”veis de ambiente, a menos que
+ vocĆŖ especifique a variĆ”vel de ambiente do shell `TRANSFORMERS_CACHE`.
+
+
+
+
+## Modo Offline
+
+O š¤ Transformers tambĆ©m pode ser executado num ambiente de firewall ou fora da rede (offline) usando arquivos locais.
+Para tal, configure a variƔvel de ambiente de modo que `TRANSFORMERS_OFFLINE=1`.
+
+
+
+VocĆŖ pode adicionar o [š¤ Datasets](https://huggingface.co/docs/datasets/) ao pipeline de treinamento offline declarando
+ a variƔvel de ambiente `HF_DATASETS_OFFLINE=1`.
+
+
+
+Segue um exemplo de execuĆ§Ć£o do programa numa rede padrĆ£o com firewall para instĆ¢ncias externas, usando o seguinte comando:
+
+```bash
+python examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
+```
+
+Execute esse mesmo programa numa instĆ¢ncia offline com o seguinte comando:
+
+```bash
+HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \
+python examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --dataset_name wmt16 --dataset_config ro-en ...
+```
+
+O script agora deve ser executado sem travar ou expirar, pois procurarĆ” apenas por arquivos locais.
+
+### Obtendo modelos e tokenizers para uso offline
+
+Outra opĆ§Ć£o para usar o š¤ Transformers offline Ć© baixar os arquivos antes e depois apontar para o caminho local onde estĆ£o localizados. Existem trĆŖs maneiras de fazer isso:
+
+* Baixe um arquivo por meio da interface de usuĆ”rio do [Model Hub](https://huggingface.co/models) clicando no Ćcone ā.
+
+ ![download-icon](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/download-icon.png)
+
+
+* Use o pipeline do [`PreTrainedModel.from_pretrained`] e [`PreTrainedModel.save_pretrained`]:
+ 1. Baixa os arquivos previamente com [`PreTrainedModel.from_pretrained`]:
+
+ ```py
+ >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/T0_3B")
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0_3B")
+ ```
+
+
+ 2. Salve os arquivos em um diretĆ³rio especĆfico com [`PreTrainedModel.save_pretrained`]:
+
+ ```py
+ >>> tokenizer.save_pretrained("./your/path/bigscience_t0")
+ >>> model.save_pretrained("./your/path/bigscience_t0")
+ ```
+
+ 3. Quando estiver offline, acesse os arquivos com [`PreTrainedModel.from_pretrained`] do diretĆ³rio especificado:
+
+ ```py
+ >>> tokenizer = AutoTokenizer.from_pretrained("./your/path/bigscience_t0")
+ >>> model = AutoModel.from_pretrained("./your/path/bigscience_t0")
+ ```
+
+* Baixando arquivos programaticamente com a biblioteca [huggingface_hub](https://github.com/huggingface/huggingface_hub/tree/main/src/huggingface_hub):
+
+ 1. Instale a biblioteca [huggingface_hub](https://github.com/huggingface/huggingface_hub/tree/main/src/huggingface_hub) em seu ambiente virtual:
+
+ ```bash
+ python -m pip install huggingface_hub
+ ```
+
+ 2. Utiliza a funĆ§Ć£o [`hf_hub_download`](https://huggingface.co/docs/hub/adding-a-library#download-files-from-the-hub) para baixar um arquivo para um caminho especĆfico. Por exemplo, o comando a seguir baixarĆ” o arquivo `config.json` para o modelo [T0](https://huggingface.co/bigscience/T0_3B) no caminho desejado:
+
+ ```py
+ >>> from huggingface_hub import hf_hub_download
+
+ >>> hf_hub_download(repo_id="bigscience/T0_3B", filename="config.json", cache_dir="./your/path/bigscience_t0")
+ ```
+
+Depois que o arquivo for baixado e armazenado no cachĆŖ local, especifique seu caminho local para carregĆ”-lo e usĆ”-lo:
+
+```py
+>>> from transformers import AutoConfig
+
+>>> config = AutoConfig.from_pretrained("./your/path/bigscience_t0/config.json")
+```
+
+
+
+Para obter mais detalhes sobre como baixar arquivos armazenados no Hub, consulte a seĆ§Ć£o [How to download files from the Hub](https://huggingface.co/docs/hub/how-to-downstream).
+
+
diff --git a/docs/source/pt/multilingual.mdx b/docs/source/pt/multilingual.mdx
new file mode 100644
index 00000000000000..4db9b54dab34fe
--- /dev/null
+++ b/docs/source/pt/multilingual.mdx
@@ -0,0 +1,191 @@
+
+
+# Modelos multilinguĆsticos para inferĆŖncia
+
+[[open-in-colab]]
+
+Existem vĆ”rios modelos multilinguĆsticos no š¤ Transformers e seus usos para inferĆŖncia diferem dos modelos monolĆngues.
+No entanto, nem *todos* os usos dos modelos multilĆngues sĆ£o tĆ£o diferentes.
+Alguns modelos, como o [bert-base-multilingual-uncased](https://huggingface.co/bert-base-multilingual-uncased),
+podem ser usados como se fossem monolĆngues. Este guia irĆ” te ajudar a usar modelos multilĆngues cujo uso difere
+para o propĆ³sito de inferĆŖncia.
+
+## XLM
+
+O XLM tem dez checkpoints diferentes dos quais apenas um Ć© monolĆngue.
+Os nove checkpoints restantes do modelo sĆ£o subdivididos em duas categorias:
+checkpoints que usam de language embeddings e os que nĆ£o.
+
+### XLM com language embeddings
+
+Os seguintes modelos de XLM usam language embeddings para especificar a linguagem utilizada para a inferĆŖncia.
+
+- `xlm-mlm-ende-1024` (Masked language modeling, English-German)
+- `xlm-mlm-enfr-1024` (Masked language modeling, English-French)
+- `xlm-mlm-enro-1024` (Masked language modeling, English-Romanian)
+- `xlm-mlm-xnli15-1024` (Masked language modeling, XNLI languages)
+- `xlm-mlm-tlm-xnli15-1024` (Masked language modeling + translation, XNLI languages)
+- `xlm-clm-enfr-1024` (Causal language modeling, English-French)
+- `xlm-clm-ende-1024` (Causal language modeling, English-German)
+
+Os language embeddings sĆ£o representados por um tensor de mesma dimensĆ£o que os `input_ids` passados ao modelo.
+Os valores destes tensores dependem do idioma utilizado e se identificam pelos atributos `lang2id` e `id2lang` do tokenizador.
+
+Neste exemplo, carregamos o checkpoint `xlm-clm-enfr-1024`(Causal language modeling, English-French):
+
+```py
+>>> import torch
+>>> from transformers import XLMTokenizer, XLMWithLMHeadModel
+
+>>> tokenizer = XLMTokenizer.from_pretrained("xlm-clm-enfr-1024")
+>>> model = XLMWithLMHeadModel.from_pretrained("xlm-clm-enfr-1024")
+```
+
+O atributo `lang2id` do tokenizador mostra os idiomas deste modelo e seus ids:
+
+```py
+>>> print(tokenizer.lang2id)
+{'en': 0, 'fr': 1}
+```
+
+Em seguida, cria-se um input de exemplo:
+
+```py
+>>> input_ids = torch.tensor([tokenizer.encode("Wikipedia was used to")]) # batch size of 1
+```
+
+Estabelece-se o id do idioma, por exemplo `"en"`, e utiliza-se o mesmo para definir a language embedding.
+A language embedding Ć© um tensor preenchido com `0`, que Ć© o id de idioma para o inglĆŖs.
+Este tensor deve ser do mesmo tamanho que os `input_ids`.
+
+```py
+>>> language_id = tokenizer.lang2id["en"] # 0
+>>> langs = torch.tensor([language_id] * input_ids.shape[1]) # torch.tensor([0, 0, 0, ..., 0])
+
+>>> # We reshape it to be of size (batch_size, sequence_length)
+>>> langs = langs.view(1, -1) # is now of shape [1, sequence_length] (we have a batch size of 1)
+```
+
+Agora vocĆŖ pode passar os `input_ids` e a language embedding ao modelo:
+
+```py
+>>> outputs = model(input_ids, langs=langs)
+```
+
+O script [run_generation.py](https://github.com/huggingface/transformers/tree/master/examples/pytorch/text-generation/run_generation.py) pode gerar um texto com language embeddings utilizando os checkpoints `xlm-clm`.
+
+### XLM sem language embeddings
+
+Os seguintes modelos XLM nĆ£o requerem o uso de language embeddings durante a inferĆŖncia:
+
+- `xlm-mlm-17-1280` (Modelagem de linguagem com mƔscara, 17 idiomas)
+- `xlm-mlm-100-1280` (Modelagem de linguagem com mƔscara, 100 idiomas)
+
+Estes modelos sĆ£o utilizados para representaƧƵes genĆ©ricas de frase diferentemente dos checkpoints XLM anteriores.
+
+## BERT
+
+Os seguintes modelos do BERT podem ser utilizados para tarefas multilinguĆsticas:
+
+- `bert-base-multilingual-uncased` (Modelagem de linguagem com mĆ”scara + PrevisĆ£o de frases, 102 idiomas)
+- `bert-base-multilingual-cased` (Modelagem de linguagem com mĆ”scara + PrevisĆ£o de frases, 104 idiomas)
+
+Estes modelos nĆ£o requerem language embeddings durante a inferĆŖncia. Devem identificar a linguagem a partir
+do contexto e realizar a inferĆŖncia em sequĆŖncia.
+
+## XLM-RoBERTa
+
+Os seguintes modelos do XLM-RoBERTa podem ser utilizados para tarefas multilinguĆsticas:
+
+- `xlm-roberta-base` (Modelagem de linguagem com mƔscara, 100 idiomas)
+- `xlm-roberta-large` Modelagem de linguagem com mƔscara, 100 idiomas)
+
+O XLM-RoBERTa foi treinado com 2,5 TB de dados do CommonCrawl recƩm-criados e testados em 100 idiomas.
+Proporciona fortes vantagens sobre os modelos multilinguĆsticos publicados anteriormente como o mBERT e o XLM em tarefas
+subsequentes como a classificaĆ§Ć£o, a rotulagem de sequĆŖncias e Ć respostas a perguntas.
+
+## M2M100
+
+Os seguintes modelos de M2M100 podem ser utilizados para traduƧƵes multilinguĆsticas:
+
+- `facebook/m2m100_418M` (TraduĆ§Ć£o)
+- `facebook/m2m100_1.2B` (TraduĆ§Ć£o)
+
+Neste exemplo, o checkpoint `facebook/m2m100_418M` Ć© carregado para traduzir do mandarim ao inglĆŖs. Ć possĆvel
+estabelecer o idioma de origem no tokenizador:
+
+```py
+>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
+
+>>> en_text = "Do not meddle in the affairs of wizards, for they are subtle and quick to anger."
+>>> chinese_text = "äøč¦ęęå·«åø«ēäŗå, å ēŗä»åęÆå¾®å¦ē, å¾åæ«å°±ęē¼ę."
+
+>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="zh")
+>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
+```
+
+TokenizaĆ§Ć£o do texto:
+
+```py
+>>> encoded_zh = tokenizer(chinese_text, return_tensors="pt")
+```
+
+O M2M100 forƧa o id do idioma de destino como o primeiro token gerado para traduzir ao idioma de destino.
+Ć definido o `forced_bos_token_id` como `en` no mĆ©todo `generate` para traduzir ao inglĆŖs.
+
+```py
+>>> generated_tokens = model.generate(**encoded_zh, forced_bos_token_id=tokenizer.get_lang_id("en"))
+>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+'Do not interfere with the matters of the witches, because they are delicate and will soon be angry.'
+```
+
+## MBart
+
+Os seguintes modelos do MBart podem ser utilizados para traduĆ§Ć£o multilinguĆstica:
+
+- `facebook/mbart-large-50-one-to-many-mmt` (TraduĆ§Ć£o automĆ”tica multilinguĆstica de um a vĆ”rios, 50 idiomas)
+- `facebook/mbart-large-50-many-to-many-mmt` (TraduĆ§Ć£o automĆ”tica multilinguĆstica de vĆ”rios a vĆ”rios, 50 idiomas)
+- `facebook/mbart-large-50-many-to-one-mmt` (TraduĆ§Ć£o automĆ”tica multilinguĆstica vĆ”rios a um, 50 idiomas)
+- `facebook/mbart-large-50` (TraduĆ§Ć£o multilinguĆstica, 50 idiomas)
+- `facebook/mbart-large-cc25`
+
+Neste exemplo, carrega-se o checkpoint `facebook/mbart-large-50-many-to-many-mmt` para traduzir do finlandĆŖs ao inglĆŖs.
+Pode-se definir o idioma de origem no tokenizador:
+
+```py
+>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+>>> en_text = "Do not meddle in the affairs of wizards, for they are subtle and quick to anger."
+>>> fi_text = "ĆlƤ sekaannu velhojen asioihin, sillƤ ne ovat hienovaraisia ja nopeasti vihaisia."
+
+>>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt", src_lang="fi_FI")
+>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
+```
+
+Tokenizando o texto:
+
+```py
+>>> encoded_en = tokenizer(en_text, return_tensors="pt")
+```
+
+O MBart forƧa o id do idioma de destino como o primeiro token gerado para traduzir ao idioma de destino.
+Ć definido o `forced_bos_token_id` como `en` no mĆ©todo `generate` para traduzir ao inglĆŖs.
+
+```py
+>>> generated_tokens = model.generate(**encoded_en, forced_bos_token_id=tokenizer.lang_code_to_id("en_XX"))
+>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+"Don't interfere with the wizard's affairs, because they are subtle, will soon get angry."
+```
+
+Se estiver usando o checkpoint `facebook/mbart-large-50-many-to-one-mmt` nĆ£o serĆ” necessĆ”rio forƧar o id do idioma de destino
+como sendo o primeiro token generado, caso contrƔrio a usagem Ʃ a mesma.
diff --git a/docs/source/pt/pipeline_tutorial.mdx b/docs/source/pt/pipeline_tutorial.mdx
new file mode 100644
index 00000000000000..05c9e87bc2f52c
--- /dev/null
+++ b/docs/source/pt/pipeline_tutorial.mdx
@@ -0,0 +1,153 @@
+
+
+# Pipelines para inferĆŖncia
+
+Um [pipeline] simplifica o uso dos modelos no [Model Hub](https://huggingface.co/models) para a inferĆŖncia de uma diversidade de tarefas,
+como a geraĆ§Ć£o de texto, a segmentaĆ§Ć£o de imagens e a classificaĆ§Ć£o de Ć”udio.
+Inclusive, se nĆ£o tem experiĆŖncia com alguma modalidade especĆfica ou nĆ£o compreende o cĆ³digo que forma os modelos,
+pode usar eles mesmo assim com o [pipeline]! Este tutorial te ensinarĆ” a:
+
+* Utilizar um [`pipeline`] para inferĆŖncia.
+* Utilizar um tokenizador ou model especĆfico.
+* Utilizar um [`pipeline`] para tarefas de Ć”udio e visĆ£o computacional.
+
+
+
+ Acesse a documentaĆ§Ć£o do [`pipeline`] para obter uma lista completa de tarefas possĆveis.
+
+
+
+## Uso do pipeline
+
+Mesmo que cada tarefa tenha um [`pipeline`] associado, Ć© mais simples usar a abstraĆ§Ć£o geral do [`pipeline`] que
+contĆ©m todos os pipelines das tarefas mais especĆficas.
+O [`pipeline`] carrega automaticamenta um modelo predeterminado e um tokenizador com capacidade de inferĆŖncia para sua
+tarefa.
+
+1. Comece carregando um [`pipeline`] e especifique uma tarefa de inferĆŖncia:
+
+```py
+>>> from transformers import pipeline
+
+>>> generator = pipeline(task="text-generation")
+```
+
+2. Passe seu dado de entrada, no caso um texto, ao [`pipeline`]:
+
+```py
+>>> generator("Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone")
+[{'generated_text': 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Seven for the Iron-priests at the door to the east, and thirteen for the Lord Kings at the end of the mountain'}]
+```
+
+Se tiver mais de uma entrada, passe-a como uma lista:
+
+```py
+>>> generator(
+... [
+... "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone",
+... "Nine for Mortal Men, doomed to die, One for the Dark Lord on his dark throne",
+... ]
+... )
+```
+
+Qualquer parĆ¢metro adicional para a sua tarefa tambĆ©m pode ser incluĆdo no [`pipeline`]. A tarefa `text-generation` tem um mĆ©todo
+[`~generation_utils.GenerationMixin.generate`] com vĆ”rios parĆ¢metros para controlar a saĆda.
+Por exemplo, se quiser gerar mais de uma saĆda, defina-a no parĆ¢metro `num_return_sequences`:
+
+```py
+>>> generator(
+... "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone",
+... num_return_sequences=2,
+... )
+```
+
+### Selecionando um modelo e um tokenizador
+
+O [`pipeline`] aceita qualquer modelo do [Model Hub](https://huggingface.co/models). HĆ” rĆ³tulos adicionais no Model Hub
+que te permitem filtrar pelo modelo que gostaria de usar para sua tarefa. Uma vez que tiver escolhido o modelo apropriado,
+carregue-o com as classes `AutoModelFor` e [`AutoTokenizer'] correspondentes. Por exemplo, carregue a classe [`AutoModelForCausalLM`]
+para uma tarefa de modelagem de linguagem causal:
+
+```py
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+```
+
+Crie uma [`pipeline`] para a sua tarefa e especifĆque o modelo e o tokenizador que foram carregados:
+
+```py
+>>> from transformers import pipeline
+
+>>> generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
+```
+
+Passe seu texto de entrada ao [`pipeline`] para gerar algum texto:
+
+```py
+>>> generator("Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone")
+[{'generated_text': 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Seven for the Dragon-lords (for them to rule in a world ruled by their rulers, and all who live within the realm'}]
+```
+
+## Pipeline de audio
+
+A flexibilidade do [`pipeline`] significa que tambĆ©m pode-se extender Ć s tarefas de Ć”udio.
+La flexibilidad de [`pipeline`] significa que tambiƩn se puede extender a tareas de audio.
+
+Por exemplo, classifiquemos a emoĆ§Ć£o de um breve fragmento do famoso discurso de John F. Kennedy /home/rzimmerdev/dev/transformers/docs/source/pt/pipeline_tutorial.mdx
+Encontre um modelo de [audio classification](https://huggingface.co/models?pipeline_tag=audio-classification) para
+reconhecimento de emoƧƵes no Model Hub e carregue-o usando o [`pipeline`]:
+
+```py
+>>> from transformers import pipeline
+
+>>> audio_classifier = pipeline(
+... task="audio-classification", model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
+... )
+```
+
+Passe o arquivo de Ɣudio ao [`pipeline`]:
+
+```py
+>>> audio_classifier("jfk_moon_speech.wav")
+[{'label': 'calm', 'score': 0.13856211304664612},
+ {'label': 'disgust', 'score': 0.13148026168346405},
+ {'label': 'happy', 'score': 0.12635163962841034},
+ {'label': 'angry', 'score': 0.12439591437578201},
+ {'label': 'fearful', 'score': 0.12404385954141617}]
+```
+
+## Pipeline de visĆ£o computacional
+
+Finalmente, utilizar um [`pipeline`] para tarefas de visĆ£o Ć© praticamente a mesma coisa.
+Especifique a sua tarefa de visĆ£o e passe a sua imagem ao classificador.
+A imagem pode ser um link ou uma rota local Ć imagem. Por exemplo, que espĆ©cie de gato estĆ” presente na imagem?
+
+![pipeline-cat-chonk](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg)
+
+```py
+>>> from transformers import pipeline
+
+>>> vision_classifier = pipeline(task="image-classification")
+>>> vision_classifier(
+... images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+... )
+[{'label': 'lynx, catamount', 'score': 0.4403027892112732},
+ {'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
+ 'score': 0.03433405980467796},
+ {'label': 'snow leopard, ounce, Panthera uncia',
+ 'score': 0.032148055732250214},
+ {'label': 'Egyptian cat', 'score': 0.02353910356760025},
+ {'label': 'tiger cat', 'score': 0.023034192621707916}]
+```
diff --git a/docs/source/pt/quicktour.mdx b/docs/source/pt/quicktour.mdx
new file mode 100644
index 00000000000000..3c00a64b6652e8
--- /dev/null
+++ b/docs/source/pt/quicktour.mdx
@@ -0,0 +1,391 @@
+
+
+# Tour rƔpido
+
+[[open-in-colab]]
+
+Comece a trabalhar com š¤ Transformers! Comece usando [`pipeline`] para rĆ”pida inferĆŖncia e facilmente carregue um modelo prĆ©-treinado e um tokenizer com [AutoClass](./model_doc/auto) para resolver tarefas de texto, visĆ£o ou Ć”udio.
+
+
+
+Todos os exemplos de cĆ³digo apresentados na documentaĆ§Ć£o tĆŖm um botĆ£o no canto superior direito para escolher se vocĆŖ deseja ocultar ou mostrar o cĆ³digo no Pytorch ou no TensorFlow. Caso contrĆ”rio, Ć© esperado que funcione para ambos back-ends sem nenhuma alteraĆ§Ć£o.
+
+
+
+## Pipeline
+
+[`pipeline`] Ʃ a maneira mais fƔcil de usar um modelo prƩ-treinado para uma dada tarefa.
+
+
+
+A [`pipeline`] apoia diversas tarefas fora da caixa:
+
+**Texto**:
+* AnƔlise sentimental: classifica a polaridade de um texto.
+* GeraĆ§Ć£o de texto (em InglĆŖs): gera texto a partir de uma entrada.
+* Reconhecimento de entidade mencionada: legenda cada palavra com uma classe que a representa (pessoa, data, local, etc...)
+* Respostas: extrai uma resposta dado algum contexto e uma questĆ£o
+* MƔscara de preenchimento: preenche o espaƧo, dado um texto com mƔscaras de palavras.
+* SumarizaĆ§Ć£o: gera o resumo de um texto longo ou documento.
+* TraduĆ§Ć£o: traduz texto para outra lĆngua.
+* ExtraĆ§Ć£o de caracterĆsticas: cria um tensor que representa o texto.
+
+**Imagem**:
+* ClassificaĆ§Ć£o de imagens: classifica uma imagem.
+* SegmentaĆ§Ć£o de imagem: classifica cada pixel da imagem.
+* DetecĆ§Ć£o de objetos: detecta objetos em uma imagem.
+
+**Audio**:
+* ClassficaĆ§Ć£o de Ć”udio: legenda um trecho de Ć”udio fornecido.
+* Reconhecimento de fala automƔtico: transcreve audio em texto.
+
+
+
+Para mais detalhes sobre a [`pipeline`] e tarefas associadas, siga a documentaĆ§Ć£o [aqui](./main_classes/pipelines).
+
+
+
+### Uso da pipeline
+
+No exemplo a seguir, vocĆŖ usarĆ” [`pipeline`] para anĆ”lise sentimental.
+
+Instale as seguintes dependĆŖncias se vocĆŖ ainda nĆ£o o fez:
+
+
+
+
+```bash
+pip install torch
+```
+
+
+```bash
+pip install tensorflow
+```
+
+
+
+Importe [`pipeline`] e especifique a tarefa que deseja completar:
+
+```py
+>>> from transformers import pipeline
+
+>>> classifier = pipeline("sentiment-analysis")
+```
+
+A pipeline baixa and armazena um [modelo prĆ©-treinado](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english) padrĆ£o e tokenizer para anĆ”lise sentimental. Agora vocĆŖ pode usar `classifier` no texto alvo:
+
+```py
+>>> classifier("We are very happy to show you the š¤ Transformers library.")
+[{'label': 'POSITIVE', 'score': 0.9998}]
+```
+
+Para mais de uma sentenƧa, passe uma lista para a [`pipeline`], a qual retornarƔ uma lista de dicionƔrios:
+
+```py
+>>> results = classifier(["We are very happy to show you the š¤ Transformers library.", "We hope you don't hate it."])
+>>> for result in results:
+... print(f"label: {result['label']}, with score: {round(result['score'], 4)}")
+label: POSITIVE, with score: 0.9998
+label: NEGATIVE, with score: 0.5309
+```
+
+A [`pipeline`] tambĆ©m pode iterar sobre um Dataset inteiro. Comece instalando a biblioteca de [š¤ Datasets](https://huggingface.co/docs/datasets/):
+
+```bash
+pip install datasets
+```
+
+Crie uma [`pipeline`] com a tarefa que deseja resolver e o modelo que deseja usar.
+
+```py
+>>> import torch
+>>> from transformers import pipeline
+
+>>> speech_recognizer = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
+```
+
+A seguir, carregue uma base de dados (confira a š¤ [IniciaĆ§Ć£o em Datasets](https://huggingface.co/docs/datasets/quickstart.html) para mais detalhes) que vocĆŖ gostaria de iterar sobre. Por exemplo, vamos carregar o dataset [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14):
+
+```py
+>>> from datasets import load_dataset, Audio
+
+>>> dataset = load_dataset("PolyAI/minds14", name="en-US", split="train") # doctest: +IGNORE_RESULT
+```
+
+Precisamos garantir que a taxa de amostragem do conjunto de dados corresponda Ć taxa de amostragem em que o facebook/wav2vec2-base-960h foi treinado.
+
+```py
+>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=speech_recognizer.feature_extractor.sampling_rate))
+```
+
+Os arquivos de Ć”udio sĆ£o carregados e re-amostrados automaticamente ao chamar a coluna `"audio"`.
+Vamos extrair as arrays de formas de onda originais das primeiras 4 amostras e passĆ”-las como uma lista para o pipeline:
+
+```py
+>>> result = speech_recognizer(dataset[:4]["audio"])
+>>> print([d["text"] for d in result])
+['I WOULD LIKE TO SET UP A JOINT ACCOUNT WITH MY PARTNER HOW DO I PROCEED WITH DOING THAT', "FONDERING HOW I'D SET UP A JOIN TO HET WITH MY WIFE AND WHERE THE AP MIGHT BE", "I I'D LIKE TOY SET UP A JOINT ACCOUNT WITH MY PARTNER I'M NOT SEEING THE OPTION TO DO IT ON THE APSO I CALLED IN TO GET SOME HELP CAN I JUST DO IT OVER THE PHONE WITH YOU AND GIVE YOU THE INFORMATION OR SHOULD I DO IT IN THE AP AND I'M MISSING SOMETHING UQUETTE HAD PREFERRED TO JUST DO IT OVER THE PHONE OF POSSIBLE THINGS", 'HOW DO I TURN A JOIN A COUNT']
+```
+
+Para um conjunto de dados maior onde as entradas sĆ£o maiores (como em fala ou visĆ£o), serĆ” necessĆ”rio passar um gerador em vez de uma lista que carregue todas as entradas na memĆ³ria. Consulte a [documentaĆ§Ć£o do pipeline](./main_classes/pipelines) para mais informaƧƵes.
+
+### Use outro modelo e tokenizer na pipeline
+
+A [`pipeline`] pode acomodar qualquer modelo do [Model Hub](https://huggingface.co/models), facilitando sua adaptaĆ§Ć£o para outros casos de uso. Por exemplo, se vocĆŖ quiser um modelo capaz de lidar com texto em francĆŖs, use as tags no Model Hub para filtrar um modelo apropriado. O principal resultado filtrado retorna um [modelo BERT](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment) bilĆngue ajustado para anĆ”lise de sentimentos. Ćtimo, vamos usar este modelo!
+
+```py
+>>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
+```
+
+
+
+Use o [`AutoModelForSequenceClassification`] e [`AutoTokenizer`] para carregar o modelo prƩ-treinado e seu tokenizer associado (mais em `AutoClass` abaixo):
+
+```py
+>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+>>> model = AutoModelForSequenceClassification.from_pretrained(model_name)
+>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
+```
+
+
+
+Use o [`TFAutoModelForSequenceClassification`] and [`AutoTokenizer`] para carregar o modelo prƩ-treinado e o tokenizer associado (mais em `TFAutoClass` abaixo):
+
+```py
+>>> from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
+
+>>> model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
+>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
+```
+
+
+
+EntĆ£o vocĆŖ pode especificar o modelo e o tokenizador na [`pipeline`] e aplicar o `classifier` no seu texto alvo:
+
+```py
+>>> classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
+>>> classifier("Nous sommes trĆØs heureux de vous prĆ©senter la bibliothĆØque š¤ Transformers.")
+[{'label': '5 stars', 'score': 0.7273}]
+```
+
+Se vocĆŖ nĆ£o conseguir achar um modelo para o seu caso de uso, precisarĆ” usar fine-tune em um modelo prĆ©-treinado nos seus dados. Veja nosso [tutorial de fine-tuning](./training) para descobrir como. Finalmente, depois que vocĆŖ tiver usado esse processo em seu modelo, considere compartilhĆ”-lo conosco (veja o tutorial [aqui](./model_sharing)) na plataforma Model Hub afim de democratizar NLP! š¤
+
+## AutoClass
+
+
+
+Por baixo dos panos, as classes [`AutoModelForSequenceClassification`] e [`AutoTokenizer`] trabalham juntas para fortificar o [`pipeline`]. Um [AutoClass](./model_doc/auto) Ʃ um atalho que automaticamente recupera a arquitetura de um modelo prƩ-treinado a partir de seu nome ou caminho. Basta selecionar a `AutoClass` apropriada para sua tarefa e seu tokenizer associado com [`AutoTokenizer`].
+
+Vamos voltar ao nosso exemplo e ver como vocĆŖ pode usar a `AutoClass` para replicar os resultados do [`pipeline`].
+
+### AutoTokenizer
+
+Um tokenizer Ć© responsĆ”vel por prĆ©-processar o texto em um formato que seja compreensĆvel para o modelo. Primeiro, o tokenizer dividirĆ” o texto em palavras chamadas *tokens*. Existem vĆ”rias regras que regem o processo de tokenizaĆ§Ć£o, incluindo como dividir uma palavra e em que nĆvel (saiba mais sobre tokenizaĆ§Ć£o [aqui](./tokenizer_summary)). A coisa mais importante a lembrar, porĆ©m, Ć© que vocĆŖ precisa instanciar o tokenizer com o mesmo nome do modelo para garantir que estĆ” usando as mesmas regras de tokenizaĆ§Ć£o com as quais um modelo foi prĆ©-treinado.
+
+Carregue um tokenizer com [`AutoTokenizer`]:
+
+```py
+>>> from transformers import AutoTokenizer
+
+>>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
+>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
+```
+
+Em seguida, o tokenizer converte os tokens em nĆŗmeros para construir um tensor como entrada para o modelo. Isso Ć© conhecido como o *vocabulĆ”rio* do modelo.
+
+Passe o texto para o tokenizer:
+
+```py
+>>> encoding = tokenizer("We are very happy to show you the š¤ Transformers library.")
+>>> print(encoding)
+{'input_ids': [101, 11312, 10320, 12495, 19308, 10114, 11391, 10855, 10103, 100, 58263, 13299, 119, 102],
+ 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
+```
+
+O tokenizer retornarƔ um dicionƔrio contendo:
+
+* [input_ids](./glossary#input-ids): representaƧƵes numƩricas de seus tokens.
+* [atttention_mask](.glossary#attention-mask): indica quais tokens devem ser atendidos.
+
+Assim como o [`pipeline`], o tokenizer aceitarƔ uma lista de entradas. AlƩm disso, o tokenizer tambƩm pode preencher e truncar o texto para retornar um lote com comprimento uniforme:
+
+
+
+```py
+>>> pt_batch = tokenizer(
+... ["We are very happy to show you the š¤ transformers library.", "We hope you don't hate it."],
+... padding=True,
+... truncation=True,
+... max_length=512,
+... return_tensors="pt",
+... )
+```
+
+
+```py
+>>> tf_batch = tokenizer(
+... ["We are very happy to show you the š¤ Transformers library.", "We hope you don't hate it."],
+... padding=True,
+... truncation=True,
+... max_length=512,
+... return_tensors="tf",
+... )
+```
+
+
+
+Leia o tutorial de [prĆ©-processamento](./prĆ©-processamento) para obter mais detalhes sobre tokenizaĆ§Ć£o.
+
+### AutoModel
+
+
+
+š¤ Transformers fornecem uma maneira simples e unificada de carregar instĆ¢ncias prĆ©-treinadas. Isso significa que vocĆŖ pode carregar um [`AutoModel`] como carregaria um [`AutoTokenizer`]. A Ćŗnica diferenƧa Ć© selecionar o [`AutoModel`] correto para a tarefa. Como vocĆŖ estĆ” fazendo classificaĆ§Ć£o de texto ou sequĆŖncia, carregue [`AutoModelForSequenceClassification`]:
+
+```py
+>>> from transformers import AutoModelForSequenceClassification
+
+>>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(model_name)
+```
+
+
+
+Veja o [sumƔrio de tarefas](./task_summary) para qual classe de [`AutoModel`] usar para cada tarefa.
+
+
+
+Agora vocĆŖ pode passar seu grupo de entradas prĆ©-processadas diretamente para o modelo. VocĆŖ apenas tem que descompactar o dicionĆ”rio usando `**`:
+
+```py
+>>> pt_outputs = pt_model(**pt_batch)
+```
+
+O modelo gera as ativaƧƵes finais no atributo `logits`. Aplique a funĆ§Ć£o softmax aos `logits` para recuperar as probabilidades:
+
+```py
+>>> from torch import nn
+
+>>> pt_predictions = nn.functional.softmax(pt_outputs.logits, dim=-1)
+>>> print(pt_predictions)
+tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
+ [0.2084, 0.1826, 0.1969, 0.1755, 0.2365]], grad_fn=)
+```
+
+
+š¤ Transformers fornecem uma maneira simples e unificada de carregar instĆ¢ncias prĆ©-treinadas. Isso significa que vocĆŖ pode carregar um [`TFAutoModel`] como carregaria um [`AutoTokenizer`]. A Ćŗnica diferenƧa Ć© selecionar o [`TFAutoModel`] correto para a tarefa. Como vocĆŖ estĆ” fazendo classificaĆ§Ć£o de texto ou sequĆŖncia, carregue [`TFAutoModelForSequenceClassification`]:
+
+```py
+>>> from transformers import TFAutoModelForSequenceClassification
+
+>>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
+```
+
+
+
+Veja o [sumƔrio de tarefas](./task_summary) para qual classe de [`AutoModel`] usar para cada tarefa.
+
+
+
+Agora vocĆŖ pode passar seu grupo de entradas prĆ©-processadas diretamente para o modelo atravĆ©s da passagem de chaves de dicionĆ”rios ao tensor.
+
+```py
+>>> tf_outputs = tf_model(tf_batch)
+```
+
+O modelo gera as ativaƧƵes finais no atributo `logits`. Aplique a funĆ§Ć£o softmax aos `logits` para recuperar as probabilidades:
+
+```py
+>>> import tensorflow as tf
+
+>>> tf_predictions = tf.nn.softmax(tf_outputs.logits, axis=-1)
+>>> tf_predictions # doctest: +IGNORE_RESULT
+```
+
+
+
+
+
+Todos os modelos de š¤ Transformers (PyTorch ou TensorFlow) geram tensores *antes* da funĆ§Ć£o de ativaĆ§Ć£o final (como softmax) pois essa funĆ§Ć£o algumas vezes Ć© fundida com a perda.
+
+
+
+
+Os modelos sĆ£o um standard [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) ou um [`tf.keras.Model`](https: //www.tensorflow.org/api_docs/python/tf/keras/Model) para que vocĆŖ possa usĆ”-los em seu loop de treinamento habitual. No entanto, para facilitar as coisas, š¤ Transformers fornece uma classe [`Trainer`] para PyTorch que adiciona funcionalidade para treinamento distribuĆdo, precisĆ£o mista e muito mais. Para o TensorFlow, vocĆŖ pode usar o mĆ©todo `fit` de [Keras](https://keras.io/). Consulte o [tutorial de treinamento](./training) para obter mais detalhes.
+
+
+
+As saĆdas do modelo š¤ Transformers sĆ£o classes de dados especiais para que seus atributos sejam preenchidos automaticamente em um IDE.
+As saĆdas do modelo tambĆ©m se comportam como uma tupla ou um dicionĆ”rio (por exemplo, vocĆŖ pode indexar com um inteiro, uma parte ou uma string), caso em que os atributos `None` sĆ£o ignorados.
+
+
+
+### Salvar um modelo
+
+
+
+Uma vez que seu modelo estiver afinado, vocĆŖ pode salvĆ”-lo com seu Tokenizer usando [`PreTrainedModel.save_pretrained`]:
+
+```py
+>>> pt_save_directory = "./pt_save_pretrained"
+>>> tokenizer.save_pretrained(pt_save_directory) # doctest: +IGNORE_RESULT
+>>> pt_model.save_pretrained(pt_save_directory)
+```
+
+Quando vocĆŖ estiver pronto para usĆ”-lo novamente, recarregue com [`PreTrainedModel.from_pretrained`]:
+
+```py
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained("./pt_save_pretrained")
+```
+
+
+Uma vez que seu modelo estiver afinado, vocĆŖ pode salvĆ”-lo com seu Tokenizer usando [`TFPreTrainedModel.save_pretrained`]:
+
+```py
+>>> tf_save_directory = "./tf_save_pretrained"
+>>> tokenizer.save_pretrained(tf_save_directory) # doctest: +IGNORE_RESULT
+>>> tf_model.save_pretrained(tf_save_directory)
+```
+
+Quando vocĆŖ estiver pronto para usĆ”-lo novamente, recarregue com [`TFPreTrainedModel.from_pretrained`]
+
+```py
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained("./tf_save_pretrained")
+```
+
+
+
+Um recurso particularmente interessante dos š¤ Transformers Ć© a capacidade de salvar um modelo e recarregĆ”-lo como um modelo PyTorch ou TensorFlow. Use `from_pt` ou `from_tf` para converter o modelo de um framework para outro:
+
+
+
+```py
+>>> from transformers import AutoModel
+
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+```
+
+
+```py
+>>> from transformers import TFAutoModel
+
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+```
+
+
\ No newline at end of file
diff --git a/docs/source/pt/tasks/sequence_classification.mdx b/docs/source/pt/tasks/sequence_classification.mdx
new file mode 100644
index 00000000000000..7c443e700d4edd
--- /dev/null
+++ b/docs/source/pt/tasks/sequence_classification.mdx
@@ -0,0 +1,212 @@
+
+
+# ClassificaĆ§Ć£o de texto
+
+
+
+A classificaĆ§Ć£o de texto Ć© uma tarefa comum de NLP que atribui um rĆ³tulo ou classe a um texto. Existem muitas aplicaƧƵes prĆ”ticas de classificaĆ§Ć£o de texto amplamente utilizadas em produĆ§Ć£o por algumas das maiores empresas da atualidade. Uma das formas mais populares de classificaĆ§Ć£o de texto Ć© a anĆ”lise de sentimento, que atribui um rĆ³tulo como positivo, negativo ou neutro a um texto.
+
+Este guia mostrarĆ” como realizar o fine-tuning do [DistilBERT](https://huggingface.co/distilbert-base-uncased) no conjunto de dados [IMDb](https://huggingface.co/datasets/imdb) para determinar se a crĆtica de filme Ć© positiva ou negativa.
+
+
+
+Consulte a [pĆ”gina de tarefas de classificaĆ§Ć£o de texto](https://huggingface.co/tasks/text-classification) para obter mais informaƧƵes sobre outras formas de classificaĆ§Ć£o de texto e seus modelos, conjuntos de dados e mĆ©tricas associados.
+
+
+
+## Carregue o conjunto de dados IMDb
+
+Carregue o conjunto de dados IMDb utilizando a biblioteca š¤ Datasets:
+
+```py
+>>> from datasets import load_dataset
+
+>>> imdb = load_dataset("imdb")
+```
+
+Em seguida, dĆŖ uma olhada em um exemplo:
+
+```py
+>>> imdb["test"][0]
+{
+ "label": 0,
+ "text": "I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn't match the background, and painfully one-dimensional characters cannot be overcome with a 'sci-fi' setting. (I'm sure there are those of you out there who think Babylon 5 is good sci-fi TV. It's not. It's clichƩd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It's really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it's rubbish as they have to always say \"Gene Roddenberry's Earth...\" otherwise people would not continue watching. Roddenberry's ashes must be turning in their orbit as this dull, cheap, poorly edited (watching it without advert breaks really brings this home) trudging Trabant of a show lumbers into space. Spoiler. So, kill off a main character. And then bring him back as another actor. Jeeez! Dallas all over again.",
+}
+```
+
+Existem dois campos neste dataset:
+
+- `text`: uma string contendo o texto da crĆtica do filme.
+- `label`: um valor que pode ser `0` para uma crĆtica negativa ou `1` para uma crĆtica positiva.
+
+## PrƩ-processamento dos dados
+
+Carregue o tokenizador do DistilBERT para processar o campo `text`:
+
+```py
+>>> from transformers import AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+```
+
+Crie uma funĆ§Ć£o de prĆ©-processamento para tokenizar o campo `text` e truncar as sequĆŖncias para que nĆ£o sejam maiores que o comprimento mĆ”ximo de entrada do DistilBERT:
+
+```py
+>>> def preprocess_function(examples):
+... return tokenizer(examples["text"], truncation=True)
+```
+
+Use a funĆ§Ć£o [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) do š¤ Datasets para aplicar a funĆ§Ć£o de prĆ©-processamento em todo o conjunto de dados. VocĆŖ pode acelerar a funĆ§Ć£o `map` definindo `batched=True` para processar vĆ”rios elementos do conjunto de dados de uma sĆ³ vez:
+
+```py
+tokenized_imdb = imdb.map(preprocess_function, batched=True)
+```
+
+Use o [`DataCollatorWithPadding`] para criar um batch de exemplos. Ele tambĆ©m *preencherĆ” dinamicamente* seu texto atĆ© o comprimento do elemento mais longo em seu batch, para que os exemplos do batch tenham um comprimento uniforme. Embora seja possĆvel preencher seu texto com a funĆ§Ć£o `tokenizer` definindo `padding=True`, o preenchimento dinĆ¢mico utilizando um data collator Ć© mais eficiente.
+
+
+
+```py
+>>> from transformers import DataCollatorWithPadding
+
+>>> data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+```
+
+
+```py
+>>> from transformers import DataCollatorWithPadding
+
+>>> data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
+```
+
+
+
+## Train
+
+
+
+Carregue o DistilBERT com [`AutoModelForSequenceClassification`] junto com o nĆŗmero de rĆ³tulos esperados:
+
+```py
+>>> from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
+
+>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
+```
+
+
+
+Se vocĆŖ nĆ£o estiver familiarizado com o fine-tuning de um modelo com o [`Trainer`], dĆŖ uma olhada no tutorial bĆ”sico [aqui](../training#finetune-with-trainer)!
+
+
+
+Nesse ponto, restam apenas trĆŖs passos:
+
+1. Definir seus hiperparĆ¢metros de treinamento em [`TrainingArguments`].
+2. Passar os argumentos de treinamento para o [`Trainer`] junto com o modelo, conjunto de dados, tokenizador e o data collator.
+3. Chamar a funĆ§Ć£o [`~Trainer.train`] para executar o fine-tuning do seu modelo.
+
+```py
+>>> training_args = TrainingArguments(
+... output_dir="./results",
+... learning_rate=2e-5,
+... per_device_train_batch_size=16,
+... per_device_eval_batch_size=16,
+... num_train_epochs=5,
+... weight_decay=0.01,
+... )
+
+>>> trainer = Trainer(
+... model=model,
+... args=training_args,
+... train_dataset=tokenized_imdb["train"],
+... eval_dataset=tokenized_imdb["test"],
+... tokenizer=tokenizer,
+... data_collator=data_collator,
+... )
+
+>>> trainer.train()
+```
+
+
+
+O [`Trainer`] aplicarĆ” o preenchimento dinĆ¢mico por padrĆ£o quando vocĆŖ definir o argumento `tokenizer` dele. Nesse caso, vocĆŖ nĆ£o precisa especificar um data collator explicitamente.
+
+
+
+
+Para executar o fine-tuning de um modelo no TensorFlow, comece convertendo seu conjunto de dados para o formato `tf.data.Dataset` com [`to_tf_dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.to_tf_dataset). Nessa execuĆ§Ć£o vocĆŖ deverĆ” especificar as entradas e rĆ³tulos (no parĆ¢metro `columns`), se deseja embaralhar o conjunto de dados, o tamanho do batch e o data collator:
+
+```py
+>>> tf_train_set = tokenized_imdb["train"].to_tf_dataset(
+... columns=["attention_mask", "input_ids", "label"],
+... shuffle=True,
+... batch_size=16,
+... collate_fn=data_collator,
+... )
+
+>>> tf_validation_set = tokenized_imdb["test"].to_tf_dataset(
+... columns=["attention_mask", "input_ids", "label"],
+... shuffle=False,
+... batch_size=16,
+... collate_fn=data_collator,
+... )
+```
+
+
+
+Se vocĆŖ nĆ£o estiver familiarizado com o fine-tuning de um modelo com o Keras, dĆŖ uma olhada no tutorial bĆ”sico [aqui](training#finetune-with-keras)!
+
+
+
+Configure o otimizador e alguns hiperparĆ¢metros de treinamento:
+
+```py
+>>> from transformers import create_optimizer
+>>> import tensorflow as tf
+
+>>> batch_size = 16
+>>> num_epochs = 5
+>>> batches_per_epoch = len(tokenized_imdb["train"]) // batch_size
+>>> total_train_steps = int(batches_per_epoch * num_epochs)
+>>> optimizer, schedule = create_optimizer(init_lr=2e-5, num_warmup_steps=0, num_train_steps=total_train_steps)
+```
+
+Carregue o DistilBERT com [`TFAutoModelForSequenceClassification`] junto com o nĆŗmero de rĆ³tulos esperados:
+
+```py
+>>> from transformers import TFAutoModelForSequenceClassification
+
+>>> model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
+```
+
+Configure o modelo para treinamento com o mƩtodo [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+
+```py
+>>> import tensorflow as tf
+
+>>> model.compile(optimizer=optimizer)
+```
+
+Chame o mƩtodo [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) para executar o fine-tuning do modelo:
+
+```py
+>>> model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=3)
+```
+
+
+
+
+
+Para obter um exemplo mais aprofundado de como executar o fine-tuning de um modelo para classificaĆ§Ć£o de texto, dĆŖ uma olhada nesse [notebook utilizando PyTorch](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb) ou nesse [notebook utilizando TensorFlow](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb).
+
+
\ No newline at end of file
diff --git a/docs/source/pt/tasks/token_classification.mdx b/docs/source/pt/tasks/token_classification.mdx
new file mode 100644
index 00000000000000..780080a60dd325
--- /dev/null
+++ b/docs/source/pt/tasks/token_classification.mdx
@@ -0,0 +1,268 @@
+
+
+# ClassificaĆ§Ć£o de tokens
+
+
+
+A classificaĆ§Ć£o de tokens atribui um rĆ³tulo a tokens individuais em uma frase. Uma das tarefas de classificaĆ§Ć£o de tokens mais comuns Ć© o Reconhecimento de Entidade Nomeada, tambĆ©m chamada de NER (sigla em inglĆŖs para Named Entity Recognition). O NER tenta encontrar um rĆ³tulo para cada entidade em uma frase, como uma pessoa, local ou organizaĆ§Ć£o.
+
+Este guia mostrarĆ” como realizar o fine-tuning do [DistilBERT](https://huggingface.co/distilbert-base-uncased) no conjunto de dados [WNUT 17](https://huggingface.co/datasets/wnut_17) para detectar novas entidades.
+
+
+
+Consulte a [pĆ”gina de tarefas de classificaĆ§Ć£o de tokens](https://huggingface.co/tasks/token-classification) para obter mais informaƧƵes sobre outras formas de classificaĆ§Ć£o de tokens e seus modelos, conjuntos de dados e mĆ©tricas associadas.
+
+
+
+## Carregando o conjunto de dados WNUT 17
+
+Carregue o conjunto de dados WNUT 17 da biblioteca š¤ Datasets:
+
+```py
+>>> from datasets import load_dataset
+
+>>> wnut = load_dataset("wnut_17")
+```
+
+E dĆŖ uma olhada em um exemplo:
+
+```py
+>>> wnut["train"][0]
+{'id': '0',
+ 'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0],
+ 'tokens': ['@paulwalk', 'It', "'s", 'the', 'view', 'from', 'where', 'I', "'m", 'living', 'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.', 'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.']
+}
+```
+
+Cada nĆŗmero em `ner_tags` representa uma entidade. Converta o nĆŗmero em um rĆ³tulo para obter mais informaƧƵes:
+
+```py
+>>> label_list = wnut["train"].features[f"ner_tags"].feature.names
+>>> label_list
+[
+ "O",
+ "B-corporation",
+ "I-corporation",
+ "B-creative-work",
+ "I-creative-work",
+ "B-group",
+ "I-group",
+ "B-location",
+ "I-location",
+ "B-person",
+ "I-person",
+ "B-product",
+ "I-product",
+]
+```
+
+O `ner_tag` descreve uma entidade, como uma organizaĆ§Ć£o, local ou pessoa. A letra que prefixa cada `ner_tag` indica a posiĆ§Ć£o do token da entidade:
+
+- `B-` indica o inĆcio de uma entidade.
+- `I-` indica que um token estĆ” contido dentro da mesma entidade (por exemplo, o token `State` pode fazer parte de uma entidade como `Empire State Building`).
+- `0` indica que o token nĆ£o corresponde a nenhuma entidade.
+
+## PrƩ-processamento
+
+
+
+Carregue o tokenizer do DistilBERT para processar os `tokens`:
+
+```py
+>>> from transformers import AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+```
+
+Como a entrada jĆ” foi dividida em palavras, defina `is_split_into_words=True` para tokenizar as palavras em subpalavras:
+
+```py
+>>> tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
+>>> tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
+>>> tokens
+['[CLS]', '@', 'paul', '##walk', 'it', "'", 's', 'the', 'view', 'from', 'where', 'i', "'", 'm', 'living', 'for', 'two', 'weeks', '.', 'empire', 'state', 'building', '=', 'es', '##b', '.', 'pretty', 'bad', 'storm', 'here', 'last', 'evening', '.', '[SEP]']
+```
+
+Ao adicionar os tokens especiais `[CLS]` e `[SEP]` e a tokenizaĆ§Ć£o de subpalavras uma incompatibilidade Ć© gerada entre a entrada e os rĆ³tulos. Uma Ćŗnica palavra correspondente a um Ćŗnico rĆ³tulo pode ser dividida em duas subpalavras. VocĆŖ precisarĆ” realinhar os tokens e os rĆ³tulos da seguinte forma:
+
+1. Mapeie todos os tokens para a palavra correspondente com o mƩtodo [`word_ids`](https://huggingface.co/docs/tokenizers/python/latest/api/reference.html#tokenizers.Encoding.word_ids).
+2. Atribuindo o rĆ³tulo `-100` aos tokens especiais `[CLS]` e `[SEP]` para que a funĆ§Ć£o de loss do PyTorch ignore eles.
+3. Rotular apenas o primeiro token de uma determinada palavra. Atribuindo `-100` a outros subtokens da mesma palavra.
+
+Aqui estĆ” como vocĆŖ pode criar uma funĆ§Ć£o para realinhar os tokens e rĆ³tulos e truncar sequĆŖncias para nĆ£o serem maiores que o comprimento mĆ”ximo de entrada do DistilBERT:
+
+```py
+>>> def tokenize_and_align_labels(examples):
+... tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
+
+... labels = []
+... for i, label in enumerate(examples[f"ner_tags"]):
+... word_ids = tokenized_inputs.word_ids(batch_index=i) # Map tokens to their respective word.
+... previous_word_idx = None
+... label_ids = []
+... for word_idx in word_ids: # Set the special tokens to -100.
+... if word_idx is None:
+... label_ids.append(-100)
+... elif word_idx != previous_word_idx: # Only label the first token of a given word.
+... label_ids.append(label[word_idx])
+... else:
+... label_ids.append(-100)
+... previous_word_idx = word_idx
+... labels.append(label_ids)
+
+... tokenized_inputs["labels"] = labels
+... return tokenized_inputs
+```
+
+Use a funĆ§Ć£o [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) do š¤ Datasets para tokenizar e alinhar os rĆ³tulos em todo o conjunto de dados. VocĆŖ pode acelerar a funĆ§Ć£o `map` configurando `batched=True` para processar vĆ”rios elementos do conjunto de dados de uma sĆ³ vez:
+
+```py
+>>> tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True)
+```
+
+Use o [`DataCollatorForTokenClassification`] para criar um batch de exemplos. Ele tambĆ©m *preencherĆ” dinamicamente* seu texto e rĆ³tulos para o comprimento do elemento mais longo em seu batch, para que tenham um comprimento uniforme. Embora seja possĆvel preencher seu texto na funĆ§Ć£o `tokenizer` configurando `padding=True`, o preenchimento dinĆ¢mico Ć© mais eficiente.
+
+
+
+```py
+>>> from transformers import DataCollatorForTokenClassification
+
+>>> data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
+```
+
+
+```py
+>>> from transformers import DataCollatorForTokenClassification
+
+>>> data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors="tf")
+```
+
+
+
+## Treinamento
+
+
+
+Carregue o DistilBERT com o [`AutoModelForTokenClassification`] junto com o nĆŗmero de rĆ³tulos esperados:
+
+```py
+>>> from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
+
+>>> model = AutoModelForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=14)
+```
+
+
+
+Se vocĆŖ nĆ£o estiver familiarizado com o fine-tuning de um modelo com o [`Trainer`], dĆŖ uma olhada no tutorial bĆ”sico [aqui](../training#finetune-with-trainer)!
+
+
+
+Nesse ponto, restam apenas trĆŖs passos:
+
+1. Definir seus hiperparĆ¢metros de treinamento em [`TrainingArguments`].
+2. Passar os argumentos de treinamento para o [`Trainer`] junto com o modelo, conjunto de dados, tokenizador e o data collator.
+3. Chamar a funĆ§Ć£o [`~Trainer.train`] para executar o fine-tuning do seu modelo.
+
+```py
+>>> training_args = TrainingArguments(
+... output_dir="./results",
+... evaluation_strategy="epoch",
+... learning_rate=2e-5,
+... per_device_train_batch_size=16,
+... per_device_eval_batch_size=16,
+... num_train_epochs=3,
+... weight_decay=0.01,
+... )
+
+>>> trainer = Trainer(
+... model=model,
+... args=training_args,
+... train_dataset=tokenized_wnut["train"],
+... eval_dataset=tokenized_wnut["test"],
+... tokenizer=tokenizer,
+... data_collator=data_collator,
+... )
+
+>>> trainer.train()
+```
+
+
+Para executar o fine-tuning de um modelo no TensorFlow, comece convertendo seu conjunto de dados para o formato `tf.data.Dataset` com [`to_tf_dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.to_tf_dataset). Nessa execuĆ§Ć£o vocĆŖ deverĆ” especificar as entradas e rĆ³tulos (no parĆ¢metro `columns`), se deseja embaralhar o conjunto de dados, o tamanho do batch e o data collator:
+
+```py
+>>> tf_train_set = tokenized_wnut["train"].to_tf_dataset(
+... columns=["attention_mask", "input_ids", "labels"],
+... shuffle=True,
+... batch_size=16,
+... collate_fn=data_collator,
+... )
+
+>>> tf_validation_set = tokenized_wnut["validation"].to_tf_dataset(
+... columns=["attention_mask", "input_ids", "labels"],
+... shuffle=False,
+... batch_size=16,
+... collate_fn=data_collator,
+... )
+```
+
+
+
+Se vocĆŖ nĆ£o estiver familiarizado com o fine-tuning de um modelo com o Keras, dĆŖ uma olhada no tutorial bĆ”sico [aqui](training#finetune-with-keras)!
+
+
+
+Configure o otimizador e alguns hiperparĆ¢metros de treinamento:
+
+```py
+>>> from transformers import create_optimizer
+
+>>> batch_size = 16
+>>> num_train_epochs = 3
+>>> num_train_steps = (len(tokenized_wnut["train"]) // batch_size) * num_train_epochs
+>>> optimizer, lr_schedule = create_optimizer(
+... init_lr=2e-5,
+... num_train_steps=num_train_steps,
+... weight_decay_rate=0.01,
+... num_warmup_steps=0,
+... )
+```
+
+Carregue o DistilBERT com o [`TFAutoModelForTokenClassification`] junto com o nĆŗmero de rĆ³tulos esperados:
+
+```py
+>>> from transformers import TFAutoModelForTokenClassification
+
+>>> model = TFAutoModelForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
+```
+
+Configure o modelo para treinamento com o mƩtodo [`compile`](https://keras.io/api/models/model_training_apis/#compile-method):
+
+```py
+>>> import tensorflow as tf
+
+>>> model.compile(optimizer=optimizer)
+```
+
+Chame o mƩtodo [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) para executar o fine-tuning do modelo:
+
+```py
+>>> model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=3)
+```
+
+
+
+
+
+Para obter um exemplo mais aprofundado de como executar o fine-tuning de um modelo para classificaĆ§Ć£o de tokens, dĆŖ uma olhada nesse [notebook utilizando PyTorch](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb) ou nesse [notebook utilizando TensorFlow](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification-tf.ipynb).
+
+
\ No newline at end of file
diff --git a/docs/source/pt/training.mdx b/docs/source/pt/training.mdx
new file mode 100644
index 00000000000000..3d3697d79190de
--- /dev/null
+++ b/docs/source/pt/training.mdx
@@ -0,0 +1,413 @@
+
+
+# Fine-tuning de um modelo prƩ-treinado
+
+[[open-in-colab]]
+
+O uso de um modelo prĆ©-treinado tem importantes vantagens. ReduĆ§Ć£o do custo computacional, a pegada de carbono, e te
+permite utilizar modelos de Ćŗltima geraĆ§Ć£o sem ter que treinar um novo desde o inĆcio.
+O š¤ Transformers proporciona acesso a milhares de modelos prĆ©-treinados numa ampla gama de tarefas.
+Quando utilizar um modelo prĆ©-treinado, treine-o com um dataset especĆfico para a sua tarefa.
+Isto Ʃ chamado de fine-tuning, uma tƩcnica de treinamento incrivelmente poderosa. Neste tutorial faremos o fine-tuning
+a um modelo prƩ-treinado com um framework de Deep Learning de sua escolha:
+
+* Fine-tuning de um modelo prĆ©-treinado com o š¤ Transformers [`Trainer`].
+* Fine-tuning de um modelo prƩ-treinado no TensorFlow com o Keras.
+* Fine-tuning de um modelo prƩ-treinado em PyTorch nativo.
+
+
+
+## Preparando um dataset
+
+
+
+Antes de aplicar o fine-tuning a um modelo prƩ-treinado, baixe um dataset e prepare-o para o treinamento.
+O tutorial anterior ensinarĆ” a processar os dados para o treinamento, e entĆ£o poderĆ” ter a oportunidade de testar
+esse novo conhecimento em algo prƔtico.
+
+Comece carregando o dataset [Yelp Reviews](https://huggingface.co/datasets/yelp_review_full):
+
+```py
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("yelp_review_full")
+>>> dataset[100]
+{'label': 0,
+ 'text': 'My expectations for McDonalds are t rarely high. But for one to still fail so spectacularly...that takes something special!\\nThe cashier took my friends\'s order, then promptly ignored me. I had to force myself in front of a cashier who opened his register to wait on the person BEHIND me. I waited over five minutes for a gigantic order that included precisely one kid\'s meal. After watching two people who ordered after me be handed their food, I asked where mine was. The manager started yelling at the cashiers for \\"serving off their orders\\" when they didn\'t have their food. But neither cashier was anywhere near those controls, and the manager was the one serving food to customers and clearing the boards.\\nThe manager was rude when giving me my order. She didn\'t make sure that I had everything ON MY RECEIPT, and never even had the decency to apologize that I felt I was getting poor service.\\nI\'ve eaten at various McDonalds restaurants for over 30 years. I\'ve worked at more than one location. I expect bad days, bad moods, and the occasional mistake. But I have yet to have a decent experience at this store. It will remain a place I avoid unless someone in my party needs to avoid illness from low blood sugar. Perhaps I should go back to the racially biased service of Steak n Shake instead!'}
+```
+
+Como jƔ sabe, Ʃ necessƔrio ter um tokenizador para processar o texto e incluir uma estratƩgia de padding e truncamento,
+para manejar qualquer tamanho varĆavel de sequĆŖncia. Para processar o seu dataset em apenas um passo, utilize o mĆ©todo de
+š¤ Datasets [`map`](https://huggingface.co/docs/datasets/process.html#map) para aplicar uma funĆ§Ć£o de preprocessamento sobre
+todo o dataset.
+
+```py
+>>> from transformers import AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+
+
+>>> def tokenize_function(examples):
+... return tokenizer(examples["text"], padding="max_length", truncation=True)
+
+
+>>> tokenized_datasets = dataset.map(tokenize_function, batched=True)
+```
+
+Se desejar, Ć© possĆvel criar um subconjunto menor do dataset completo para aplicar o fine-tuning e assim reduzir o tempo necessĆ”rio.
+
+```py
+>>> small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
+>>> small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
+```
+
+
+
+## Fine-tuning com o `Trainer`
+
+
+
+O š¤ Transformers proporciona uma classe [`Trainer`] otimizada para o treinamento de modelos de š¤ Transformers,
+facilitando os primeiros passos do treinamento sem a necessidade de escrever manualmente seu prĆ³prio ciclo.
+A API do [`Trainer`] suporta um grande conjunto de opƧƵes de treinamento e funcionalidades, como o logging,
+o gradient accumulation e o mixed precision.
+
+Comece carregando seu modelo e especifique o nĆŗmero de labels de previsĆ£o.
+A partir do [Card Dataset](https://huggingface.co/datasets/yelp_review_full#data-fields) do Yelp Reveiw, que ja
+sabemos ter 5 labels usamos o seguinte cĆ³digo:
+
+```py
+>>> from transformers import AutoModelForSequenceClassification
+
+>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
+```
+
+
+
+ VocĆŖ verĆ” um alerta sobre alguns pesos prĆ©-treinados que nĆ£o estĆ£o sendo utilizados e que alguns pesos estĆ£o
+ sendo inicializados aleatoriamente. NĆ£o se preocupe, essa mensagem Ć© completamente normal.
+ O header/cabeƧƔrio prĆ©-treinado do modelo BERT Ć© descartado e substitui-se por um header de classificaĆ§Ć£o
+ inicializado aleatoriamente. Assim, pode aplicar o fine-tuning a este novo header do modelo em sua tarefa
+ de classificaĆ§Ć£o de sequĆŖncias fazendo um transfer learning do modelo prĆ©-treinado.
+
+
+
+### HiperparĆ¢metros de treinamento
+
+Em seguida, crie uma classe [`TrainingArguments`] que contenha todos os hiperparĆ¢metros que possam ser ajustados, assim
+como os indicadores para ativar as diferentes opƧƵes de treinamento. Para este tutorial, vocĆŖ pode comeƧar o treinamento
+usando os [hiperparĆ”metros](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) padrĆ£o,
+porĆ©m, sinta-se livre para experimentar com eles e encontrar uma configuraĆ§Ć£o Ć³tima.
+
+Especifique onde salvar os checkpoints do treinamento:
+
+```py
+>>> from transformers import TrainingArguments
+
+>>> training_args = TrainingArguments(output_dir="test_trainer")
+```
+
+### MĆ©tricas
+
+O [`Trainer`] nĆ£o avalia automaticamente o rendimento do modelo durante o treinamento. SerĆ” necessĆ”rio passar ao
+[`Trainer`] uma funĆ§Ć£o para calcular e fazer um diagnĆ³stico sobre as mĆ©tricas. A biblioteca š¤ Datasets proporciona
+uma funĆ§Ć£o de [`accuracy`](https://huggingface.co/metrics/accuracy) simples que pode ser carregada com a funĆ§Ć£o
+`load_metric` (ver este [tutorial](https://huggingface.co/docs/datasets/metrics.html) para mais informaƧƵes):
+
+```py
+>>> import numpy as np
+>>> from datasets import load_metric
+
+>>> metric = load_metric("accuracy")
+```
+
+Defina a funĆ§Ć£o `compute` dentro de `metric` para calcular a precisĆ£o de suas prediƧƵes.
+Antes de passar suas prediƧƵes ao `compute`, Ć© necessĆ”rio converter as prediƧƵes Ć logits (lembre-se que
+todos os modelos de š¤ Transformers retornam logits).
+
+```py
+>>> def compute_metrics(eval_pred):
+... logits, labels = eval_pred
+... predictions = np.argmax(logits, axis=-1)
+... return metric.compute(predictions=predictions, references=labels)
+```
+
+Se quiser controlar suas mĆ©tricas de avaliaĆ§Ć£o durante o fine-tuning, especifique o parĆ¢metro `evaluation_strategy`
+em seus argumentos de treinamento para que o modelo leve em conta a mĆ©trica de avaliaĆ§Ć£o ao final de cada Ć©poca:
+
+```py
+>>> from transformers import TrainingArguments
+
+>>> training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
+```
+
+### Trainer
+
+Crie um objeto [`Trainer`] com seu modelo, argumentos de treinamento, conjuntos de dados de treinamento e de teste, e sua funĆ§Ć£o de avaliaĆ§Ć£o:
+
+```py
+>>> trainer = Trainer(
+... model=model,
+... args=training_args,
+... train_dataset=small_train_dataset,
+... eval_dataset=small_eval_dataset,
+... compute_metrics=compute_metrics,
+... )
+```
+
+Em seguida, aplique o fine-tuning a seu modelo chamado [`~transformers.Trainer.train`]:
+
+```py
+>>> trainer.train()
+```
+
+
+
+## Fine-tuning com Keras
+
+
+
+Os modelos de š¤ Transformers tambĆ©m permitem realizar o treinamento com o TensorFlow com a API do Keras.
+Contudo, serƔ necessƔrio fazer algumas mudanƧas antes de realizar o fine-tuning.
+
+### ConversĆ£o do dataset ao formato do TensorFlow
+
+O [`DefaultDataCollator`] junta os tensores em um batch para que o modelo possa ser treinado em cima deles.
+Assegure-se de especificar os `return_tensors` para retornar os tensores do TensorFlow:
+
+```py
+>>> from transformers import DefaultDataCollator
+
+>>> data_collator = DefaultDataCollator(return_tensors="tf")
+```
+
+
+
+ O [`Trainer`] utiliza [`DataCollatorWithPadding`] por padrĆ£o, entĆ£o vocĆŖ nĆ£o precisa especificar explicitamente um
+ colador de dados (data collator).
+
+
+
+Em seguida, converta os datasets tokenizados em datasets do TensorFlow com o mƩtodo
+[`to_tf_dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.to_tf_dataset).
+Especifique suas entradas em `columns` e seu rĆ³tulo em `label_cols`:
+
+```py
+>>> tf_train_dataset = small_train_dataset.to_tf_dataset(
+... columns=["attention_mask", "input_ids", "token_type_ids"],
+... label_cols=["labels"],
+... shuffle=True,
+... collate_fn=data_collator,
+... batch_size=8,
+... )
+
+>>> tf_validation_dataset = small_eval_dataset.to_tf_dataset(
+... columns=["attention_mask", "input_ids", "token_type_ids"],
+... label_cols=["labels"],
+... shuffle=False,
+... collate_fn=data_collator,
+... batch_size=8,
+... )
+```
+
+### CompilaĆ§Ć£o e ajustes
+
+Carregue um modelo do TensorFlow com o nĆŗmero esperado de rĆ³tulos:
+
+```py
+>>> import tensorflow as tf
+>>> from transformers import TFAutoModelForSequenceClassification
+
+>>> model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
+```
+
+A seguir, compile e ajuste o fine-tuning a seu modelo com [`fit`](https://keras.io/api/models/model_training_apis/) como
+faria com qualquer outro modelo do Keras:
+
+```py
+>>> model.compile(
+... optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
+... loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+... metrics=tf.metrics.SparseCategoricalAccuracy(),
+... )
+
+>>> model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3)
+```
+
+
+
+## Fine-tune em PyTorch nativo
+
+
+
+O [`Trainer`] se encarrega do ciclo de treinamento e permite aplicar o fine-tuning a um modelo em uma linha de cĆ³digo apenas.
+Para os usuĆ”rios que preferirem escrever seu prĆ³prio ciclo de treinamento, tambĆ©m Ć© possĆvel aplicar o fine-tuning a um
+modelo de š¤ Transformers em PyTorch nativo.
+
+Neste momento, talvez ocorra a necessidade de reinicar seu notebook ou executar a seguinte linha de cĆ³digo para liberar
+memĆ³ria:
+
+```py
+del model
+del pytorch_model
+del trainer
+torch.cuda.empty_cache()
+```
+
+Em sequĆŖncia, faremos um post-processing manual do `tokenized_dataset` e assim preparĆ”-lo para o treinamento.
+
+1. Apague a coluna de `text` porque o modelo nĆ£o aceita texto cru como entrada:
+
+ ```py
+ >>> tokenized_datasets = tokenized_datasets.remove_columns(["text"])
+ ```
+
+2. Troque o nome da coluna `label` para `labels`, pois o modelo espera um argumento de mesmo nome:
+
+ ```py
+ >>> tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+ ```
+
+3. Defina o formato do dataset para retornar tensores do PyTorch no lugar de listas:
+
+ ```py
+ >>> tokenized_datasets.set_format("torch")
+ ```
+
+Em sequĆŖncia, crie um subconjunto menor do dataset, como foi mostrado anteriormente, para acelerĆ”-lo o fine-tuning.
+
+```py
+>>> small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
+>>> small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
+```
+
+### DataLoader
+
+Crie um `DataLoader` para seus datasets de treinamento e de teste para poder iterar sobre batches de dados:
+
+```py
+>>> from torch.utils.data import DataLoader
+
+>>> train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
+>>> eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)
+```
+
+Carregue seu modelo com o nĆŗmero de labels esperados:
+
+```py
+>>> from transformers import AutoModelForSequenceClassification
+
+>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
+```
+
+### OtimizaĆ§Ć£o e configuraĆ§Ć£o do Learning Rate
+
+Crie um otimizador e um learning rate para aplicar o fine-tuning ao modelo.
+Iremos utilizar o otimizador [`AdamW`](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html) do PyTorch:
+
+```py
+>>> from torch.optim import AdamW
+
+>>> optimizer = AdamW(model.parameters(), lr=5e-5)
+```
+
+Defina o learning rate do [`Trainer`]:
+
+```py
+>>> from transformers import get_scheduler
+
+>>> num_epochs = 3
+>>> num_training_steps = num_epochs * len(train_dataloader)
+>>> lr_scheduler = get_scheduler(
+... name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
+... )
+```
+
+Por Ćŗltimo, especifique o `device` do ambiente para utilizar uma GPU se tiver acesso Ć alguma. Caso contrĆ”rio, o treinamento
+em uma CPU pode acabar levando vƔrias horas em vez de minutos.
+
+```py
+>>> import torch
+
+>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+>>> model.to(device)
+```
+
+
+
+ Se necessĆ”rio, vocĆŖ pode obter o acesso gratuito a uma GPU na nĆŗvem por meio de um notebook no
+ [Colaboratory](https://colab.research.google.com/) ou [SageMaker StudioLab](https://studiolab.sagemaker.aws/)
+ se nĆ£o tiver esse recurso de forma local.
+
+
+
+Perfeito, agora estamos prontos para comeƧar o treinamento! š„³
+Genial, Ā”ahora estamos listos entrenar! š„³
+
+### Ciclo de treinamento
+
+Para visualizar melhor o processo de treinamento, utilize a biblioteca [tqdm](https://tqdm.github.io/) para adicionar
+uma barra de progresso sobre o nĆŗmero de passos percorridos no treinamento atual:
+
+```py
+>>> from tqdm.auto import tqdm
+
+>>> progress_bar = tqdm(range(num_training_steps))
+
+>>> model.train()
+>>> for epoch in range(num_epochs):
+... for batch in train_dataloader:
+... batch = {k: v.to(device) for k, v in batch.items()}
+... outputs = model(**batch)
+... loss = outputs.loss
+... loss.backward()
+
+... optimizer.step()
+... lr_scheduler.step()
+... optimizer.zero_grad()
+... progress_bar.update(1)
+```
+
+### MĆ©tricas
+
+Da mesma forma que Ć© necessĆ”rio adicionar uma funĆ§Ć£o de avaliaĆ§Ć£o ao [`Trainer`], Ć© necessĆ”rio fazer o mesmo quando
+escrevendo o prĆ³prio ciclo de treinamento. Contudo, em vez de calcular e retornar a mĆ©trica final de cada Ć©poca,
+vocĆŖ deverĆ” adicionar todos os batches com [`add_batch`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=add_batch#datasets.Metric.add_batch)
+e calcular a mƩtrica apenas no final.
+
+```py
+>>> metric = load_metric("accuracy")
+>>> model.eval()
+>>> for batch in eval_dataloader:
+... batch = {k: v.to(device) for k, v in batch.items()}
+... with torch.no_grad():
+... outputs = model(**batch)
+
+... logits = outputs.logits
+... predictions = torch.argmax(logits, dim=-1)
+... metric.add_batch(predictions=predictions, references=batch["labels"])
+
+>>> metric.compute()
+```
+
+
+
+## Recursos adicionais
+
+Para mais exemplos de fine-tuning acesse:
+
+- [š¤ Transformers Examples](https://github.com/huggingface/transformers/tree/main/examples) inclui scripts
+para treinas tarefas comuns de NLP em PyTorch e TensorFlow.
+
+- [š¤ Transformers Notebooks](notebooks) contĆ©m vĆ”rios notebooks sobre como aplicar o fine-tuning a um modelo
+para tarefas especĆficas no PyTorch e TensorFlow.
diff --git a/examples/flax/image-captioning/create_model_from_encoder_decoder_models.py b/examples/flax/image-captioning/create_model_from_encoder_decoder_models.py
index 953aa136e97a61..ab2fb8568d5205 100644
--- a/examples/flax/image-captioning/create_model_from_encoder_decoder_models.py
+++ b/examples/flax/image-captioning/create_model_from_encoder_decoder_models.py
@@ -42,14 +42,18 @@ class ModelArguments:
)
encoder_model_name_or_path: str = field(
metadata={
- "help": "The encoder model checkpoint for weights initialization."
- "Don't set if you want to train an encoder model from scratch."
+ "help": (
+ "The encoder model checkpoint for weights initialization."
+ "Don't set if you want to train an encoder model from scratch."
+ )
},
)
decoder_model_name_or_path: str = field(
metadata={
- "help": "The decoder model checkpoint for weights initialization."
- "Don't set if you want to train a decoder model from scratch."
+ "help": (
+ "The decoder model checkpoint for weights initialization."
+ "Don't set if you want to train a decoder model from scratch."
+ )
},
)
encoder_config_name: Optional[str] = field(
diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py
index b1c9012777ac5b..149d3abff5e300 100644
--- a/examples/flax/image-captioning/run_image_captioning_flax.py
+++ b/examples/flax/image-captioning/run_image_captioning_flax.py
@@ -52,7 +52,7 @@
HfArgumentParser,
is_tensorboard_available,
)
-from transformers.utils import get_full_repo_name, is_offline_mode
+from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry
logger = logging.getLogger(__name__)
@@ -175,14 +175,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -222,38 +227,48 @@ class DataTrainingArguments:
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the `max_length` param of `model.generate`, which is used "
- "during evaluation."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the `max_length` param of `model.generate`, which is used "
+ "during evaluation."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -266,8 +281,10 @@ class DataTrainingArguments:
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
- "which is used during evaluation."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
+ "which is used during evaluation."
+ )
},
)
overwrite_cache: bool = field(
@@ -371,6 +388,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_image_captioning", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -623,7 +644,7 @@ def preprocess_fn(examples, max_target_length, check_image=True):
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
if training_args.block_size % train_batch_size > 0 or training_args.block_size % eval_batch_size > 0:
raise ValueError(
- f"`training_args.block_size` needs to be a multiple of the global train/eval batch size."
+ "`training_args.block_size` needs to be a multiple of the global train/eval batch size."
f"Got {training_args.block_size}, {train_batch_size} and {eval_batch_size} respectively instead."
)
@@ -1136,7 +1157,7 @@ def predict(rng: jax.random.PRNGKey, dataset: Dataset):
)
# train
- for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
+ for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
cur_step += 1
batch = next(train_batches)
@@ -1150,7 +1171,10 @@ def predict(rng: jax.random.PRNGKey, dataset: Dataset):
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
_train_metric = unreplicate(train_metric)
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
+ desc = (
+ f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
+ f" Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
+ )
epochs.desc = desc
epochs.write(desc)
diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py
index afb6d75b38570e..1bf088df29c495 100755
--- a/examples/flax/language-modeling/run_clm_flax.py
+++ b/examples/flax/language-modeling/run_clm_flax.py
@@ -58,7 +58,7 @@
set_seed,
)
from transformers.testing_utils import CaptureLogger
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
logger = logging.getLogger(__name__)
@@ -138,8 +138,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -162,14 +163,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -194,15 +200,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -217,9 +227,11 @@ class DataTrainingArguments:
block_size: Optional[int] = field(
default=None,
metadata={
- "help": "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
overwrite_cache: bool = field(
@@ -316,6 +328,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -505,7 +521,8 @@ def tokenize_function(examples):
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
- "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+ " before being passed to the model."
)
return output
@@ -735,7 +752,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
+ f" {train_metric['learning_rate'].mean()})"
)
train_metrics = []
@@ -762,7 +780,10 @@ def eval_step(params, batch):
eval_metrics["perplexity"] = float("inf")
# Print metrics and update progress bar
- desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
+ desc = (
+ f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:"
+ f" {eval_metrics['perplexity']})"
+ )
epochs.write(desc)
epochs.desc = desc
diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py
index 6ea0f6e1564f51..3538ba268334a3 100755
--- a/examples/flax/language-modeling/run_mlm_flax.py
+++ b/examples/flax/language-modeling/run_mlm_flax.py
@@ -58,7 +58,7 @@
is_tensorboard_available,
set_seed,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
@@ -136,8 +136,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -160,14 +161,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -209,8 +215,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -223,8 +231,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
line_by_line: bool = field(
@@ -355,6 +365,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mlm", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -764,7 +778,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py
index 368ecf0e61c05f..48a58b60c0a821 100755
--- a/examples/flax/language-modeling/run_t5_mlm_flax.py
+++ b/examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -57,7 +57,7 @@
set_seed,
)
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
@@ -135,8 +135,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -159,14 +160,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -208,7 +214,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization and masking. Sequences longer than this"
+ " will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -337,12 +346,14 @@ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarra
if batch["input_ids"].shape[-1] != self.input_length:
raise ValueError(
- f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
+ f" should be {self.target_length}."
)
if batch["labels"].shape[-1] != self.target_length:
raise ValueError(
- f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
+ f" {self.target_length}."
)
# to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
@@ -487,6 +498,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_t5_mlm", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -884,7 +899,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
+ f" {train_metric['learning_rate'].mean()})"
)
train_metrics = []
diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py
index ac4ec706bfcf79..5c4fe750a993e5 100644
--- a/examples/flax/question-answering/run_qa.py
+++ b/examples/flax/question-answering/run_qa.py
@@ -53,14 +53,14 @@
PreTrainedTokenizerFast,
is_tensorboard_available,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from utils_qa import postprocess_qa_predictions
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset
@@ -157,14 +157,19 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -200,37 +205,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -239,9 +253,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -255,8 +271,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -406,6 +424,10 @@ def main():
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa", model_args, data_args, framework="flax")
# endregion
# region Logging
@@ -498,9 +520,9 @@ def main():
# region Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# endregion
@@ -928,7 +950,8 @@ def eval_step(state, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py
index 3ebff73b98ff20..0de02fe950f901 100644
--- a/examples/flax/summarization/run_summarization_flax.py
+++ b/examples/flax/summarization/run_summarization_flax.py
@@ -54,7 +54,7 @@
HfArgumentParser,
is_tensorboard_available,
)
-from transformers.utils import get_full_repo_name, is_offline_mode
+from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry
logger = logging.getLogger(__name__)
@@ -149,8 +149,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -173,14 +174,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -217,45 +223,57 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the `max_length` param of `model.generate`, which is used "
- "during evaluation."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the `max_length` param of `model.generate`, which is used "
+ "during evaluation."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -271,8 +289,10 @@ class DataTrainingArguments:
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
- "which is used during evaluation."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
+ "which is used during evaluation."
+ )
},
)
overwrite_cache: bool = field(
@@ -379,6 +399,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_summarization", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -831,7 +855,8 @@ def generate_step(params, batch):
train_metric = unreplicate(train_metric)
epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py
index 23144069d7dd81..d32f70a4c165dc 100755
--- a/examples/flax/text-classification/run_flax_glue.py
+++ b/examples/flax/text-classification/run_flax_glue.py
@@ -48,12 +48,12 @@
TrainingArguments,
is_tensorboard_available,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset
@@ -103,8 +103,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -148,29 +150,37 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If set, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
@@ -298,6 +308,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_glue", model_args, data_args, framework="flax")
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -585,7 +599,8 @@ def eval_step(state, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py
index a0e01b080275cb..2d6f37f9350ec4 100644
--- a/examples/flax/token-classification/run_flax_ner.py
+++ b/examples/flax/token-classification/run_flax_ner.py
@@ -47,13 +47,13 @@
HfArgumentParser,
is_tensorboard_available,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
@@ -150,8 +150,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -196,36 +198,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If set, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
label_all_tokens: bool = field(
default=False,
metadata={
- "help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
- "one (in which case the other tokens will have a padding index)."
+ "help": (
+ "Whether to put the label for one word on all tokens of generated by that word or just on the "
+ "one (in which case the other tokens will have a padding index)."
+ )
},
)
return_entity_level_metrics: bool = field(
@@ -354,6 +366,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_ner", model_args, data_args, framework="flax")
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -693,7 +709,8 @@ def compute_metrics():
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
@@ -744,7 +761,8 @@ def compute_metrics():
logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}")
else:
logger.info(
- f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})"
+ f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc:"
+ f" {eval_metrics['accuracy']})"
)
if has_tensorboard and jax.process_index() == 0:
diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py
index 0dc7b2f9574291..d8ddd13cefcd38 100644
--- a/examples/flax/vision/run_image_classification.py
+++ b/examples/flax/vision/run_image_classification.py
@@ -53,7 +53,7 @@
is_tensorboard_available,
set_seed,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
logger = logging.getLogger(__name__)
@@ -134,8 +134,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -151,14 +152,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -179,15 +185,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -246,6 +256,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_image_classification", model_args, data_args, framework="flax")
+
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -509,7 +523,8 @@ def eval_step(params, batch):
train_step_progress_bar.close()
epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
diff --git a/examples/legacy/multiple_choice/run_multiple_choice.py b/examples/legacy/multiple_choice/run_multiple_choice.py
index aeb9b9dc434ac0..d8007da6cb676c 100644
--- a/examples/legacy/multiple_choice/run_multiple_choice.py
+++ b/examples/legacy/multiple_choice/run_multiple_choice.py
@@ -78,8 +78,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -102,7 +104,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
diff --git a/examples/legacy/multiple_choice/utils_multiple_choice.py b/examples/legacy/multiple_choice/utils_multiple_choice.py
index 2b6b5cc18322ba..3dbc3689cc4893 100644
--- a/examples/legacy/multiple_choice/utils_multiple_choice.py
+++ b/examples/legacy/multiple_choice/utils_multiple_choice.py
@@ -182,7 +182,7 @@ def __init__(
)
def gen():
- for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
+ for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
@@ -297,7 +297,7 @@ def _read_txt(self, input_dir):
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
- for (_, data_raw) in enumerate(lines):
+ for _, data_raw in enumerate(lines):
race_id = "%s-%s" % (set_type, data_raw["race_id"])
article = data_raw["article"]
for i in range(len(data_raw["answers"])):
@@ -518,7 +518,7 @@ def convert_examples_to_features(
label_map = {label: i for i, label in enumerate(label_list)}
features = []
- for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
+ for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
choices_inputs = []
diff --git a/examples/legacy/pytorch-lightning/lightning_base.py b/examples/legacy/pytorch-lightning/lightning_base.py
index b7f53076e3bc31..b3104a25a8b129 100644
--- a/examples/legacy/pytorch-lightning/lightning_base.py
+++ b/examples/legacy/pytorch-lightning/lightning_base.py
@@ -312,8 +312,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
diff --git a/examples/legacy/pytorch-lightning/run_glue.py b/examples/legacy/pytorch-lightning/run_glue.py
index abb06bf526bbb7..63b58bcf413c26 100644
--- a/examples/legacy/pytorch-lightning/run_glue.py
+++ b/examples/legacy/pytorch-lightning/run_glue.py
@@ -148,8 +148,10 @@ def add_model_specific_args(parser, root_dir):
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
diff --git a/examples/legacy/pytorch-lightning/run_ner.py b/examples/legacy/pytorch-lightning/run_ner.py
index 1066c6fed48cc9..b1bdd125c22eb8 100644
--- a/examples/legacy/pytorch-lightning/run_ner.py
+++ b/examples/legacy/pytorch-lightning/run_ner.py
@@ -173,8 +173,10 @@ def add_model_specific_args(parser, root_dir):
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
diff --git a/examples/legacy/question-answering/run_squad.py b/examples/legacy/question-answering/run_squad.py
index fbf2ebd6351abb..674e7a9accbf3a 100644
--- a/examples/legacy/question-answering/run_squad.py
+++ b/examples/legacy/question-answering/run_squad.py
@@ -551,8 +551,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument(
"--doc_stride",
@@ -564,8 +566,10 @@ def main():
"--max_query_length",
default=64,
type=int,
- help="The maximum number of tokens for the question. Questions longer than this will "
- "be truncated to this length.",
+ help=(
+ "The maximum number of tokens for the question. Questions longer than this will "
+ "be truncated to this length."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -610,20 +614,27 @@ def main():
"--max_answer_length",
default=30,
type=int,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--verbose_logging",
action="store_true",
- help="If true, all of the warnings related to data processing will be printed. "
- "A number of warnings are expected for a normal SQuAD evaluation.",
+ help=(
+ "If true, all of the warnings related to data processing will be printed. "
+ "A number of warnings are expected for a normal SQuAD evaluation."
+ ),
)
parser.add_argument(
"--lang_id",
default=0,
type=int,
- help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
+ help=(
+ "language id of input for language-specific xlm models (see"
+ " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
+ ),
)
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
@@ -652,8 +663,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
diff --git a/examples/legacy/question-answering/run_squad_trainer.py b/examples/legacy/question-answering/run_squad_trainer.py
index 7089326372ea54..314b140e828c59 100644
--- a/examples/legacy/question-answering/run_squad_trainer.py
+++ b/examples/legacy/question-answering/run_squad_trainer.py
@@ -84,7 +84,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
diff --git a/examples/legacy/run_language_modeling.py b/examples/legacy/run_language_modeling.py
index 12b62f5d816cea..59490f710e1338 100755
--- a/examples/legacy/run_language_modeling.py
+++ b/examples/legacy/run_language_modeling.py
@@ -68,7 +68,10 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization. Leave None if you want to train a model from"
+ " scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -99,8 +102,10 @@ class DataTrainingArguments:
train_data_files: Optional[str] = field(
default=None,
metadata={
- "help": "The input training data files (multiple files in glob format). "
- "Very often splitting large files to smaller files can prevent tokenizer going out of memory"
+ "help": (
+ "The input training data files (multiple files in glob format). "
+ "Very often splitting large files to smaller files can prevent tokenizer going out of memory"
+ )
},
)
eval_data_file: Optional[str] = field(
@@ -130,7 +135,10 @@ class DataTrainingArguments:
plm_probability: float = field(
default=1 / 6,
metadata={
- "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
+ "help": (
+ "Ratio of length of a span of masked tokens to surrounding context length for permutation language"
+ " modeling."
+ )
},
)
max_span_length: int = field(
@@ -140,9 +148,11 @@ class DataTrainingArguments:
block_size: int = field(
default=-1,
metadata={
- "help": "Optional input sequence length after tokenization."
- "The training dataset will be truncated in block of this size for training."
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization."
+ "The training dataset will be truncated in block of this size for training."
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
overwrite_cache: bool = field(
@@ -206,7 +216,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
@@ -253,8 +264,8 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
else:
raise ValueError(
- "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
- "and load it from here, using --tokenizer_name"
+ "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another"
+ " script, save it,and load it from here, using --tokenizer_name"
)
if model_args.model_name_or_path:
diff --git a/examples/legacy/run_openai_gpt.py b/examples/legacy/run_openai_gpt.py
index 2af3e267d2e78e..1f02570f8f514a 100755
--- a/examples/legacy/run_openai_gpt.py
+++ b/examples/legacy/run_openai_gpt.py
@@ -126,15 +126,15 @@ def main():
"--max_steps",
default=-1,
type=int,
- help="If > 0: set total number of training \
- steps to perform. Override num_train_epochs.",
+ help=(
+ "If > 0: set total number of training steps to perform. Override num_train_epochs."
+ ),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
- help="Number of updates steps to accumulate before\
- performing a backward/update pass.",
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--learning_rate", type=float, default=6.25e-5)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
diff --git a/examples/legacy/run_swag.py b/examples/legacy/run_swag.py
index e7760410892f9e..5cac1567243c3e 100755
--- a/examples/legacy/run_swag.py
+++ b/examples/legacy/run_swag.py
@@ -516,8 +516,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -576,8 +578,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
diff --git a/examples/legacy/seq2seq/finetune_trainer.py b/examples/legacy/seq2seq/finetune_trainer.py
index 3efc8f90f25b70..f174f7fb5018f9 100755
--- a/examples/legacy/seq2seq/finetune_trainer.py
+++ b/examples/legacy/seq2seq/finetune_trainer.py
@@ -90,31 +90,39 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=142,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. "
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. "
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
test_max_target_length: Optional[int] = field(
default=142,
metadata={
- "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for test target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
diff --git a/examples/legacy/seq2seq/old_test_calculate_rouge.py b/examples/legacy/seq2seq/old_test_calculate_rouge.py
index bd1dd57a27252b..17b87cb481a650 100644
--- a/examples/legacy/seq2seq/old_test_calculate_rouge.py
+++ b/examples/legacy/seq2seq/old_test_calculate_rouge.py
@@ -22,15 +22,30 @@
PRED = [
- 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe depression" German airline confirms it knew of Andreas Lubitz\'s depression years before he took control.',
- "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the body.",
- "Amnesty International releases its annual report on the death penalty. The report catalogs the use of state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital punishment.",
+ 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the'
+ ' final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe'
+ " depression\" German airline confirms it knew of Andreas Lubitz's depression years before he took control.",
+ "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal"
+ " accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's"
+ " founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the"
+ " body.",
+ "Amnesty International releases its annual report on the death penalty. The report catalogs the use of"
+ " state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the"
+ " world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital"
+ " punishment.",
]
TGT = [
- 'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .',
- "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June . Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .",
- "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to death . Organization claims that governments around the world are using the threat of terrorism to advance executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death sentences up by 28% .",
+ 'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .'
+ ' Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz'
+ " had informed his Lufthansa training school of an episode of severe depression, airline says .",
+ "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June ."
+ " Israel and the United States opposed the move, which could open the door to war crimes investigations against"
+ " Israelis .",
+ "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to"
+ " death . Organization claims that governments around the world are using the threat of terrorism to advance"
+ " executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death"
+ " sentences up by 28% .",
]
@@ -65,7 +80,8 @@ def test_single_sent_scores_dont_depend_on_newline_sep():
]
tgt = [
"Margot Frank, died in 1945, a month earlier than previously thought.",
- 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525.',
+ 'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of'
+ " the final seconds on board Flight 9525.",
]
assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False)
diff --git a/examples/legacy/seq2seq/run_eval.py b/examples/legacy/seq2seq/run_eval.py
index e21f57c1c609bc..a8aa8e7ef95d23 100755
--- a/examples/legacy/seq2seq/run_eval.py
+++ b/examples/legacy/seq2seq/run_eval.py
@@ -121,7 +121,10 @@ def run_generate(verbose=True):
nargs="?",
type=str,
const=datetime_now(),
- help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
+ help=(
+ "use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g."
+ " lang=en-ru. If no value is passed, the current datetime string will be used."
+ ),
)
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
diff --git a/examples/legacy/seq2seq/run_eval_search.py b/examples/legacy/seq2seq/run_eval_search.py
index f7b3bda0f54f07..e1a0c8660c9bf6 100755
--- a/examples/legacy/seq2seq/run_eval_search.py
+++ b/examples/legacy/seq2seq/run_eval_search.py
@@ -35,7 +35,7 @@ def parse_search_arg(search):
groups = search.split()
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
entry_names = list(entries.keys())
- sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()]
+ sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
matrix = [list(x) for x in itertools.product(*sets)]
return matrix, entry_names
@@ -66,7 +66,10 @@ def run_search():
prog = sys.argv[0]
parser = argparse.ArgumentParser(
- usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
+ usage=(
+ "\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore"
+ " refer to `run_eval.py -h` for the complete list."
+ )
)
parser.add_argument(
"--search",
@@ -83,7 +86,10 @@ def run_search():
nargs="?",
type=str,
const=datetime_now(),
- help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
+ help=(
+ "add custom notes to be printed before the results table. If no value is passed, the current datetime"
+ " string will be used."
+ ),
)
args, args_main = parser.parse_known_args()
# we share some of the args
diff --git a/examples/legacy/seq2seq/seq2seq_trainer.py b/examples/legacy/seq2seq/seq2seq_trainer.py
index eeff082499c4d8..dbf12725f2db07 100644
--- a/examples/legacy/seq2seq/seq2seq_trainer.py
+++ b/examples/legacy/seq2seq/seq2seq_trainer.py
@@ -57,9 +57,10 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
super().__init__(*args, **kwargs)
if config is None:
- assert isinstance(
- self.model, PreTrainedModel
- ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
+ assert isinstance(self.model, PreTrainedModel), (
+ "If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is"
+ f" {self.model.__class__}"
+ )
self.config = self.model.config
else:
self.config = config
@@ -68,13 +69,15 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
- assert (
- self.config.pad_token_id is not None
- ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
+ assert self.config.pad_token_id is not None, (
+ "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss"
+ " calculation or doing label smoothing."
+ )
if self.config.pad_token_id is None and self.config.eos_token_id is not None:
logger.warning(
- f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
+ f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for"
+ " padding.."
)
if self.args.label_smoothing == 0:
@@ -248,7 +251,8 @@ def _pad_tensors_to_max_len(self, tensor, max_length):
if pad_token_id is None:
raise ValueError(
- f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
+ "Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be"
+ f" padded to `max_length`={max_length}"
)
padded_tensor = pad_token_id * torch.ones(
diff --git a/examples/legacy/seq2seq/xla_spawn.py b/examples/legacy/seq2seq/xla_spawn.py
index d84b41994564a8..5df6bfa2d5dc31 100644
--- a/examples/legacy/seq2seq/xla_spawn.py
+++ b/examples/legacy/seq2seq/xla_spawn.py
@@ -39,9 +39,7 @@ def parse_args():
"""
parser = ArgumentParser(
description=(
- "PyTorch TPU distributed training launch "
- "helper utility that will spawn up "
- "multiple distributed processes"
+ "PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
)
)
diff --git a/examples/legacy/text-classification/run_tf_text_classification.py b/examples/legacy/text-classification/run_tf_text_classification.py
index 3564775f30ddf2..1f845db04c0448 100755
--- a/examples/legacy/text-classification/run_tf_text_classification.py
+++ b/examples/legacy/text-classification/run_tf_text_classification.py
@@ -168,8 +168,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -215,7 +217,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
diff --git a/examples/legacy/token-classification/run_ner.py b/examples/legacy/token-classification/run_ner.py
index a653ecb91c6930..477ccb50fb2565 100644
--- a/examples/legacy/token-classification/run_ner.py
+++ b/examples/legacy/token-classification/run_ner.py
@@ -87,8 +87,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -116,7 +118,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
module = import_module("tasks")
diff --git a/examples/legacy/token-classification/run_tf_ner.py b/examples/legacy/token-classification/run_tf_ner.py
index 0169a10f24ac6a..857d777238f2e2 100755
--- a/examples/legacy/token-classification/run_tf_ner.py
+++ b/examples/legacy/token-classification/run_tf_ner.py
@@ -88,8 +88,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -111,7 +113,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
module = import_module("tasks")
diff --git a/examples/legacy/token-classification/utils_ner.py b/examples/legacy/token-classification/utils_ner.py
index 2537aecfca6a0d..e1fb4d18c70b76 100644
--- a/examples/legacy/token-classification/utils_ner.py
+++ b/examples/legacy/token-classification/utils_ner.py
@@ -103,7 +103,7 @@ def convert_examples_to_features(
label_map = {label: i for i, label in enumerate(label_list)}
features = []
- for (ex_index, example) in enumerate(examples):
+ for ex_index, example in enumerate(examples):
if ex_index % 10_000 == 0:
logger.info("Writing example %d of %d", ex_index, len(examples))
diff --git a/examples/pytorch/README.md b/examples/pytorch/README.md
index c19bcfc955c3a9..95d42bfc8b3812 100644
--- a/examples/pytorch/README.md
+++ b/examples/pytorch/README.md
@@ -167,10 +167,10 @@ python xla_spawn.py --num_cores 8 \
Most PyTorch example scripts have a version using the [š¤ Accelerate](https://github.com/huggingface/accelerate) library
that exposes the training loop so it's easy for you to customize or tweak them to your needs. They all require you to
-install `accelerate` with
+install `accelerate` with the latest development version
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
Then you can easily launch any of the scripts by running
diff --git a/examples/pytorch/_tests_requirements.txt b/examples/pytorch/_tests_requirements.txt
index 9483d3a750a341..8c13e79aa44b5d 100644
--- a/examples/pytorch/_tests_requirements.txt
+++ b/examples/pytorch/_tests_requirements.txt
@@ -3,7 +3,7 @@ scikit-learn
seqeval
psutil
sacrebleu >= 1.4.12
-accelerate >= 0.5.0
+git+https://github.com/huggingface/accelerate@main#egg=accelerate
rouge-score
tensorflow_datasets
matplotlib
diff --git a/examples/pytorch/audio-classification/README.md b/examples/pytorch/audio-classification/README.md
index 12eb5e6ed399e4..21da5b9935ca14 100644
--- a/examples/pytorch/audio-classification/README.md
+++ b/examples/pytorch/audio-classification/README.md
@@ -18,13 +18,13 @@ limitations under the License.
The following examples showcase how to fine-tune `Wav2Vec2` for audio classification using PyTorch.
-Speech recognition models that have been pretrained in unsupervised fashion on audio data alone,
-*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
-[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
-[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have shown to require only
+Speech recognition models that have been pretrained in unsupervised fashion on audio data alone,
+*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
+[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
+[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have shown to require only
very little annotated data to yield good performance on speech classification datasets.
-## Single-GPU
+## Single-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the š£ļø [Keyword Spotting subset](https://huggingface.co/datasets/superb#ks) of the SUPERB dataset.
@@ -63,7 +63,9 @@ On a single V100 GPU (16GB), this script should run in ~14 minutes and yield acc
š See the results here: [anton-l/wav2vec2-base-ft-keyword-spotting](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting)
-## Multi-GPU
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
+## Multi-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) for š **Language Identification** on the [CommonLanguage dataset](https://huggingface.co/datasets/anton-l/common_language).
@@ -139,7 +141,7 @@ It has been verified that the script works for the following datasets:
| Dataset | Pretrained Model | # transformer layers | Accuracy on eval | GPU setup | Training time | Fine-tuned Model & Logs |
|---------|------------------|----------------------|------------------|-----------|---------------|--------------------------|
-| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
+| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
| Keyword Spotting | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) | 12 | 0.9826 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting) |
| Keyword Spotting | [facebook/hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960) | 12 | 0.9819 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/hubert-base-ft-keyword-spotting) |
| Keyword Spotting | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 24 | 0.9757 | 1 V100 GPU | 15min | [here](https://huggingface.co/anton-l/sew-mid-100k-ft-keyword-spotting) |
diff --git a/examples/pytorch/audio-classification/run_audio_classification.py b/examples/pytorch/audio-classification/run_audio_classification.py
index 5ad561ee2b85d5..88ca51af3f6be5 100644
--- a/examples/pytorch/audio-classification/run_audio_classification.py
+++ b/examples/pytorch/audio-classification/run_audio_classification.py
@@ -37,14 +37,14 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
@@ -86,8 +86,9 @@ class DataTrainingArguments:
eval_split_name: str = field(
default="validation",
metadata={
- "help": "The name of the training data set split to use (via the datasets library). Defaults to "
- "'validation'"
+ "help": (
+ "The name of the training data set split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
audio_column_name: str = field(
@@ -100,15 +101,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_length_seconds: float = field(
@@ -149,13 +154,19 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
freeze_feature_extractor: Optional[bool] = field(
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
def __post_init__(self):
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
@@ -186,6 +197,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_audio_classification", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -326,6 +341,7 @@ def compute_metrics(eval_pred):
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# freeze the convolutional waveform encoder
diff --git a/examples/pytorch/contrastive-image-text/run_clip.py b/examples/pytorch/contrastive-image-text/run_clip.py
index fc036f2a20fa2d..4ed5123ae0edef 100644
--- a/examples/pytorch/contrastive-image-text/run_clip.py
+++ b/examples/pytorch/contrastive-image-text/run_clip.py
@@ -47,14 +47,14 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
@@ -89,8 +89,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
freeze_vision_model: bool = field(
@@ -132,22 +134,28 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -225,6 +233,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clip", model_args, data_args)
+
# 2. Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/image-classification/README.md b/examples/pytorch/image-classification/README.md
index 2070c854c769b6..904981451c6f80 100644
--- a/examples/pytorch/image-classification/README.md
+++ b/examples/pytorch/image-classification/README.md
@@ -62,9 +62,11 @@ python run_image_classification.py \
Note that you can replace the model and dataset by simply setting the `model_name_or_path` and `dataset_name` arguments respectively, with any model or dataset from the [hub](https://huggingface.co/). For an overview of all possible arguments, we refer to the [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) of the `TrainingArguments`, which can be passed as flags.
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
### Using your own data
-To use your own dataset, there are 2 ways:
+To use your own dataset, there are 2 ways:
- you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments
- you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
@@ -177,7 +179,7 @@ the means of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate)
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
You can then use your usual launchers to run in it in a distributed environment, but the easiest way is to run
diff --git a/examples/pytorch/image-classification/requirements.txt b/examples/pytorch/image-classification/requirements.txt
index a789fee85eef5d..aadc0e9088f868 100644
--- a/examples/pytorch/image-classification/requirements.txt
+++ b/examples/pytorch/image-classification/requirements.txt
@@ -1,3 +1,3 @@
torch>=1.5.0
torchvision>=0.6.0
-datasets>=1.8.0
\ No newline at end of file
+datasets>=1.17.0
diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py
index ba85814bd78469..ac6cf0238cb71b 100644
--- a/examples/pytorch/image-classification/run_image_classification.py
+++ b/examples/pytorch/image-classification/run_image_classification.py
@@ -45,7 +45,7 @@
TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
@@ -54,7 +54,7 @@
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
@@ -93,15 +93,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -140,10 +144,16 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
def collate_fn(examples):
@@ -165,6 +175,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_image_classification", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -263,6 +277,7 @@ def compute_metrics(p):
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
index 39f805b458cc23..76b2059a1b0eee 100644
--- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py
+++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
@@ -37,6 +37,7 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
@@ -46,11 +47,11 @@
SchedulerType,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
@@ -61,7 +62,10 @@ def parse_args():
"--dataset_name",
type=str,
default="cifar10",
- help="The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private, dataset).",
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset)."
+ ),
)
parser.add_argument("--train_dir", type=str, default=None, help="A folder containing the training data.")
parser.add_argument("--validation_dir", type=str, default=None, help="A folder containing the validation data.")
@@ -69,15 +73,19 @@ def parse_args():
"--max_train_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--train_val_split",
@@ -155,7 +163,22 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
+ )
+ parser.add_argument(
+ "--ignore_mismatched_sizes",
+ action="store_true",
+ help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
@@ -178,9 +201,16 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_image_classification_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
logger.info(accelerator.state)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@@ -188,11 +218,7 @@ def main():
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -274,6 +300,7 @@ def main():
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
+ ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
# Preprocessing the datasets
@@ -362,6 +389,10 @@ def collate_fn(examples):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -370,12 +401,15 @@ def collate_fn(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("image_classification_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("image_classification_no_trainer", experiment_config)
# Get the metric function
metric = load_metric("accuracy")
@@ -469,12 +503,13 @@ def collate_fn(examples):
model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
- outputs = model(**batch)
+ with torch.no_grad():
+ outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather((predictions, batch["labels"]))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
@@ -491,10 +526,11 @@ def collate_fn(examples):
accelerator.log(
{
"accuracy": eval_metric,
- "train_loss": total_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py
index be65779fe3c882..9d90f2665b1828 100644
--- a/examples/pytorch/image-pretraining/run_mae.py
+++ b/examples/pytorch/image-pretraining/run_mae.py
@@ -34,7 +34,7 @@
ViTMAEForPreTraining,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
@@ -43,7 +43,7 @@
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
@@ -74,15 +74,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -104,8 +108,9 @@ class ModelArguments:
model_name_or_path: str = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
config_name: Optional[str] = field(
@@ -114,8 +119,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
cache_dir: Optional[str] = field(
@@ -129,8 +136,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
mask_ratio: float = field(
@@ -166,6 +175,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mae", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/image-pretraining/run_mim.py b/examples/pytorch/image-pretraining/run_mim.py
index ed39be7a1a1530..8ad1dadae99555 100644
--- a/examples/pytorch/image-pretraining/run_mim.py
+++ b/examples/pytorch/image-pretraining/run_mim.py
@@ -37,7 +37,7 @@
TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
@@ -48,7 +48,7 @@
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
@@ -87,15 +87,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -117,9 +121,11 @@ class ModelArguments:
model_name_or_path: str = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
- "checkpoint identifier on the hub. "
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
+ "checkpoint identifier on the hub. "
+ "Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -132,8 +138,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
cache_dir: Optional[str] = field(
@@ -148,20 +156,26 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
image_size: Optional[int] = field(
default=None,
metadata={
- "help": "The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
+ "help": (
+ "The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
+ )
},
)
patch_size: Optional[int] = field(
default=None,
metadata={
- "help": "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
+ "help": (
+ "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
+ )
},
)
encoder_stride: Optional[int] = field(
@@ -225,6 +239,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mim", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 04a6b4c2679484..2cd8092b7fb8df 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -48,12 +48,12 @@
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -73,8 +73,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -84,8 +85,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -109,8 +112,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -141,24 +146,30 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
block_size: Optional[int] = field(
default=None,
metadata={
- "help": "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
overwrite_cache: bool = field(
@@ -203,6 +214,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -390,7 +405,8 @@ def tokenize_function(examples):
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
- "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+ " before being passed to the model."
)
return output
diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index 3e7cfaa3aa73ad..73d1ae086371fe 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -39,12 +39,12 @@
import transformers
from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
@@ -52,11 +52,11 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -93,7 +93,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -167,7 +167,11 @@ def parse_args():
"--block_size",
type=int,
default=None,
- help="Optional input sequence length after tokenization. The training dataset will be truncated in block of this size for training. Default to the model max input length for single sentence inputs (take into account special tokens).",
+ help=(
+ "Optional input sequence length after tokenization. The training dataset will be truncated in block of"
+ " this size for training. Default to the model max input length for single sentence inputs (take into"
+ " account special tokens)."
+ ),
)
parser.add_argument(
"--preprocessing_num_workers",
@@ -201,7 +205,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -225,20 +239,23 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -450,7 +467,7 @@ def group_texts(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
@@ -475,6 +492,10 @@ def group_texts(examples):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -483,12 +504,15 @@ def group_texts(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("clm_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("clm_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -572,15 +596,23 @@ def group_texts(examples):
losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
try:
- perplexity = math.exp(torch.mean(losses))
+ eval_loss = torch.mean(losses)
+ perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
- logger.info(f"epoch {epoch}: perplexity: {perplexity}")
+ logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
if args.with_tracking:
accelerator.log(
- {"perplexity": perplexity, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "perplexity": perplexity,
+ "eval_loss": eval_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py
index 477ccff9505257..0322ac30972766 100755
--- a/examples/pytorch/language-modeling/run_mlm.py
+++ b/examples/pytorch/language-modeling/run_mlm.py
@@ -47,12 +47,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -70,8 +70,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -81,8 +82,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -106,8 +109,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -147,8 +152,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -165,22 +172,28 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -211,6 +224,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mlm", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
index d7d8d011ac8710..32d42412e3deff 100755
--- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
@@ -39,12 +39,12 @@
import transformers
from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForMaskedLM,
AutoTokenizer,
@@ -52,11 +52,11 @@
SchedulerType,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -96,7 +96,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -170,7 +170,9 @@ def parse_args():
"--max_seq_length",
type=int,
default=None,
- help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated."
+ ),
)
parser.add_argument(
"--line_by_line",
@@ -210,7 +212,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -236,20 +248,23 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mlm_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -493,7 +508,7 @@ def group_texts(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
@@ -521,6 +536,10 @@ def group_texts(examples):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -529,12 +548,15 @@ def group_texts(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("mlm_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("mlm_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -619,7 +641,8 @@ def group_texts(examples):
losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
try:
- perplexity = math.exp(torch.mean(losses))
+ eval_loss = torch.mean(losses)
+ perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
@@ -627,7 +650,14 @@ def group_texts(examples):
if args.with_tracking:
accelerator.log(
- {"perplexity": perplexity, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "perplexity": perplexity,
+ "eval_loss": eval_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py
index 8974882595aed5..78fc89f8305958 100755
--- a/examples/pytorch/language-modeling/run_plm.py
+++ b/examples/pytorch/language-modeling/run_plm.py
@@ -42,12 +42,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -63,8 +63,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
config_name: Optional[str] = field(
@@ -73,8 +74,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
tokenizer_name: Optional[str] = field(
@@ -95,8 +98,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -136,8 +141,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=512,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -147,8 +154,10 @@ class DataTrainingArguments:
plm_probability: float = field(
default=1 / 6,
metadata={
- "help": "Ratio of length of a span of masked tokens to surrounding context length for "
- "permutation language modeling."
+ "help": (
+ "Ratio of length of a span of masked tokens to surrounding context length for "
+ "permutation language modeling."
+ )
},
)
max_span_length: int = field(
@@ -161,22 +170,28 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -205,6 +220,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_plm", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/multiple-choice/README.md b/examples/pytorch/multiple-choice/README.md
index 4e3e331e05de60..735d1f5f33a017 100644
--- a/examples/pytorch/multiple-choice/README.md
+++ b/examples/pytorch/multiple-choice/README.md
@@ -53,7 +53,7 @@ the mean of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py
index cd2bdd74ad2b85..32e31a7ff35c32 100755
--- a/examples/pytorch/multiple-choice/run_swag.py
+++ b/examples/pytorch/multiple-choice/run_swag.py
@@ -43,11 +43,11 @@
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import PaddingStrategy, check_min_version
+from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
logger = logging.getLogger(__name__)
@@ -82,8 +82,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -109,30 +111,38 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If passed, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If passed, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to the maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to the maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -215,6 +225,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_swag", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/multiple-choice/run_swag_no_trainer.py b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
index 2c39d29cb108da..6e948a315bf08b 100755
--- a/examples/pytorch/multiple-choice/run_swag_no_trainer.py
+++ b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
@@ -37,12 +37,12 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForMultipleChoice,
AutoTokenizer,
@@ -51,10 +51,10 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import PaddingStrategy, get_full_repo_name
+from transformers.utils import PaddingStrategy, get_full_repo_name, send_example_telemetry
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -98,7 +98,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -193,7 +193,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -263,20 +273,23 @@ def __call__(self, features):
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_swag_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -450,7 +463,7 @@ def preprocess_function(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Use the device given by the `accelerator` object.
device = accelerator.device
@@ -475,6 +488,10 @@ def preprocess_function(examples):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -483,12 +500,15 @@ def preprocess_function(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("swag_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("swag_no_trainer", experiment_config)
# Metrics
metric = load_metric("accuracy")
@@ -573,7 +593,7 @@ def preprocess_function(examples):
predictions, references = accelerator.gather((predictions, batch["labels"]))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
@@ -588,7 +608,13 @@ def preprocess_function(examples):
if args.with_tracking:
accelerator.log(
- {"accuracy": eval_metric, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "accuracy": eval_metric,
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/question-answering/README.md b/examples/pytorch/question-answering/README.md
index 480da1d89fdd02..f6e660e972d618 100644
--- a/examples/pytorch/question-answering/README.md
+++ b/examples/pytorch/question-answering/README.md
@@ -136,7 +136,7 @@ SQuAD or a similar dataset, the main difference is that this script exposes the
You can use the script normally after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 242e8342738960..50c9557141c017 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -42,13 +42,13 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -81,8 +81,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -118,37 +120,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -157,9 +168,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -173,8 +186,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -211,6 +226,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -319,9 +338,9 @@ def main():
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# Preprocessing the datasets.
diff --git a/examples/pytorch/question-answering/run_qa_beam_search.py b/examples/pytorch/question-answering/run_qa_beam_search.py
index d46e96d21043d3..b73de15b452c9c 100755
--- a/examples/pytorch/question-answering/run_qa_beam_search.py
+++ b/examples/pytorch/question-answering/run_qa_beam_search.py
@@ -41,13 +41,13 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions_with_beam_search
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -80,8 +80,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -117,37 +119,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -156,9 +167,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -172,8 +185,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -210,6 +225,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa_beam_search", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
index 6e365c9814ced5..d1547a49231f6c 100644
--- a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
@@ -35,6 +35,7 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
@@ -48,17 +49,17 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions_with_beam_search
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
@@ -115,8 +116,10 @@ def parse_args():
"--max_seq_length",
type=int,
default=384,
- help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
- " sequences shorter will be padded if `--pad_to_max_lengh` is passed.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
+ " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
+ ),
)
parser.add_argument(
"--pad_to_max_length",
@@ -189,9 +192,11 @@ def parse_args():
"--null_score_diff_threshold",
type=float,
default=0.0,
- help="The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`.",
+ help=(
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ ),
)
parser.add_argument(
"--version_2_with_negative",
@@ -202,22 +207,28 @@ def parse_args():
"--max_answer_length",
type=int,
default=30,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
@@ -280,6 +291,10 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa_beam_search_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
@@ -289,11 +304,7 @@ def main():
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -736,6 +747,10 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
diff --git a/examples/pytorch/question-answering/run_qa_no_trainer.py b/examples/pytorch/question-answering/run_qa_no_trainer.py
index 530df23fd2d20f..8f6045386ae83e 100755
--- a/examples/pytorch/question-answering/run_qa_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_no_trainer.py
@@ -35,12 +35,12 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForQuestionAnswering,
AutoTokenizer,
@@ -50,17 +50,17 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import check_min_version, get_full_repo_name
+from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -120,8 +120,10 @@ def parse_args():
"--max_seq_length",
type=int,
default=384,
- help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
- " sequences shorter will be padded if `--pad_to_max_lengh` is passed.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
+ " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
+ ),
)
parser.add_argument(
"--pad_to_max_length",
@@ -132,7 +134,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -211,9 +213,11 @@ def parse_args():
"--null_score_diff_threshold",
type=float,
default=0.0,
- help="The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`.",
+ help=(
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ ),
)
parser.add_argument(
"--version_2_with_negative",
@@ -224,22 +228,28 @@ def parse_args():
"--max_answer_length",
type=int,
default=30,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
- help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set.",
+ help=(
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ ),
)
parser.add_argument(
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
@@ -277,7 +287,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -309,20 +329,23 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -721,7 +744,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -742,6 +765,10 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -750,12 +777,15 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("qa_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("qa_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -926,14 +956,14 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
if args.with_tracking:
log = {
"squad_v2" if args.version_2_with_negative else "squad": eval_metric,
- "train_loss": total_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
}
if args.do_predict:
log["squad_v2_predict" if args.version_2_with_negative else "squad_predict"] = predict_metric
- accelerator.log(log)
+ accelerator.log(log, step=completed_steps)
if args.output_dir is not None:
accelerator.wait_for_everyone()
diff --git a/examples/pytorch/question-answering/run_seq2seq_qa.py b/examples/pytorch/question-answering/run_seq2seq_qa.py
index cb6bd09bc40db6..bd806cc033e810 100644
--- a/examples/pytorch/question-answering/run_seq2seq_qa.py
+++ b/examples/pytorch/question-answering/run_seq2seq_qa.py
@@ -39,12 +39,12 @@
set_seed,
)
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
@@ -81,8 +81,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -130,53 +132,66 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
val_max_answer_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_answer_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -185,9 +200,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -201,8 +218,10 @@ class DataTrainingArguments:
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -252,6 +271,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_seq2seq_qa", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
index 304f8848b49b21..20e9b93a48c03f 100644
--- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
+++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
@@ -42,7 +42,7 @@
default_data_collator,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
@@ -51,7 +51,7 @@
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
@@ -194,15 +194,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
reduce_labels: Optional[bool] = field(
@@ -241,8 +245,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -260,6 +266,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_semantic_segmentation", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
index d5a6a16fe4857e..37df263f5be4d4 100644
--- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
+++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
@@ -16,7 +16,6 @@
import argparse
import json
-import logging
import math
import os
import random
@@ -34,6 +33,7 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository, hf_hub_download
from transformers import (
@@ -44,11 +44,11 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
@@ -285,7 +285,17 @@ def parse_args():
"--with_tracking",
required=False,
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -305,14 +315,17 @@ def parse_args():
def main():
args = parse_args()
- # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
- logger.info(accelerator.state)
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_semantic_segmentation_no_trainer", args)
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -479,14 +492,22 @@ def preprocess_val(example_batch):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Instantiate metric
metric = load_metric("mean_iou")
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("semantic_segmentation_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("semantic_segmentation_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -579,7 +600,8 @@ def preprocess_val(example_batch):
model.eval()
samples_seen = 0
for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
- outputs = model(**batch)
+ with torch.no_grad():
+ outputs = model(**batch)
upsampled_logits = torch.nn.functional.interpolate(
outputs.logits, size=batch["labels"].shape[-2:], mode="bilinear", align_corners=False
@@ -590,7 +612,7 @@ def preprocess_val(example_batch):
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
@@ -614,10 +636,11 @@ def preprocess_val(example_batch):
"mean_iou": eval_metrics["mean_iou"],
"mean_accuracy": eval_metrics["mean_accuracy"],
"overall_accuracy": eval_metrics["overall_accuracy"],
- "train_loss": total_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
index 680808f2e72c4d..1f6125390da2db 100755
--- a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
+++ b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
@@ -16,7 +16,6 @@
""" Pre-Training a š¤ Wav2Vec2 model on unlabeled audio data """
import argparse
-import logging
import math
import os
from dataclasses import dataclass
@@ -31,6 +30,7 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from huggingface_hub import Repository
from transformers import (
AdamW,
@@ -43,10 +43,10 @@
set_seed,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
def parse_args():
@@ -219,7 +219,10 @@ def parse_args():
"--pad_to_multiple_of",
type=int,
default=None,
- help="If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).",
+ help=(
+ "If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the"
+ " use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)."
+ ),
)
parser.add_argument(
"--adam_beta1",
@@ -360,13 +363,13 @@ def main():
# We now keep distinct sets of args, for a cleaner separation of concerns.
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_wav2vec2_pretraining_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -444,7 +447,7 @@ def main():
# only normalized-inputs-training is supported
if not feature_extractor.do_normalize:
raise ValueError(
- "Training is only supported for normalized inputs. " "Make sure ``feature_extractor.do_normalize == True``"
+ "Training is only supported for normalized inputs. Make sure ``feature_extractor.do_normalize == True``"
)
# set max & min audio length in number of samples
@@ -500,7 +503,8 @@ def prepare_dataset(batch):
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
- "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
+ "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
+ " ``config.feat_extract_norm='layer'"
)
# initialize random model
@@ -619,7 +623,7 @@ def prepare_dataset(batch):
lr_scheduler.step()
elif accelerator.is_local_main_process:
progress_bar.write(
- "Gradients have overflown - skipping update step... " f"Updating gradient scale to {scale}..."
+ f"Gradients have overflown - skipping update step... Updating gradient scale to {scale}..."
)
# update gumbel temperature
diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
index 6df37086240df0..ad2425d9fbb87a 100755
--- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
+++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
@@ -44,12 +44,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
@@ -101,9 +101,11 @@ class ModelArguments:
mask_time_prob: float = field(
default=0.05,
metadata={
- "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis."
+ )
},
)
mask_time_length: int = field(
@@ -113,8 +115,11 @@ class ModelArguments:
mask_feature_prob: float = field(
default=0.0,
metadata={
- "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
+ " bins will be masked along the time axis."
+ )
},
)
mask_feature_length: int = field(
@@ -146,8 +151,10 @@ class DataTrainingArguments:
train_split_name: str = field(
default="train+validation",
metadata={
- "help": "The name of the training data set split to use (via the datasets library). Defaults to "
- "'train+validation'"
+ "help": (
+ "The name of the training data set split to use (via the datasets library). Defaults to "
+ "'train+validation'"
+ )
},
)
eval_split_name: str = field(
@@ -174,15 +181,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
chars_to_ignore: Optional[List[str]] = list_field(
@@ -196,7 +207,10 @@ class DataTrainingArguments:
max_duration_in_seconds: float = field(
default=20.0,
metadata={
- "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
+ "help": (
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to"
+ " 'max_duration_in_seconds`"
+ )
},
)
min_duration_in_seconds: float = field(
@@ -205,17 +219,21 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "If :obj:`True`, will use the token generated when running"
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ "help": (
+ "If :obj:`True`, will use the token generated when running"
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ )
},
)
unk_token: str = field(
@@ -233,10 +251,12 @@ class DataTrainingArguments:
phoneme_language: Optional[str] = field(
default=None,
metadata={
- "help": "The target language that should be used be"
- " passed to the tokenizer for tokenization. Note that"
- " this is only relevant if the model classifies the"
- " input audio to a sequence of phoneme sequences."
+ "help": (
+ "The target language that should be used be"
+ " passed to the tokenizer for tokenization. Note that"
+ " this is only relevant if the model classifies the"
+ " input audio to a sequence of phoneme sequences."
+ )
},
)
@@ -356,6 +376,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_speech_recognition_ctc", model_args, data_args)
+
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
@@ -405,9 +429,9 @@ def main():
if data_args.audio_column_name not in raw_datasets["train"].column_names:
raise ValueError(
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
- f"{', '.join(raw_datasets['train'].column_names)}."
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
+ f" {', '.join(raw_datasets['train'].column_names)}."
)
if data_args.text_column_name not in raw_datasets["train"].column_names:
@@ -481,7 +505,12 @@ def remove_special_characters(batch):
with training_args.main_process_first():
if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
- os.remove(vocab_file)
+ try:
+ os.remove(vocab_file)
+ except OSError:
+ # in shared file-systems it might be the case that
+ # two processes try to delete the vocab file at the some time
+ pass
with training_args.main_process_first(desc="dataset map vocabulary creation"):
if not os.path.isfile(vocab_file):
@@ -720,7 +749,10 @@ def compute_metrics(pred):
"finetuned_from": model_args.model_name_or_path,
"tasks": "speech-recognition",
"tags": ["automatic-speech-recognition", data_args.dataset_name],
- "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}"
+ ),
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
}
if "common_voice" in data_args.dataset_name:
diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
index 3c368c4ae8362a..fce6b55be17856 100755
--- a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
+++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
@@ -42,12 +42,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
@@ -87,8 +87,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
freeze_feature_encoder: bool = field(
@@ -122,15 +124,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
audio_column_name: str = field(
@@ -144,7 +150,10 @@ class DataTrainingArguments:
max_duration_in_seconds: float = field(
default=20.0,
metadata={
- "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
+ "help": (
+ "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
+ " 'max_duration_in_seconds`"
+ )
},
)
min_duration_in_seconds: float = field(
@@ -153,10 +162,12 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
train_split_name: str = field(
@@ -228,6 +239,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args)
+
# 2. Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/summarization/README.md b/examples/pytorch/summarization/README.md
index bf42e796434eed..db7f8f4061a5c9 100644
--- a/examples/pytorch/summarization/README.md
+++ b/examples/pytorch/summarization/README.md
@@ -149,7 +149,7 @@ the mean of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index c35b636d7dd9b6..95be07e7185dad 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -46,12 +46,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version, is_offline_mode
+from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -101,15 +101,19 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
- "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
- "the model's position embeddings."
+ "help": (
+ "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
+ "the model's position embeddings."
+ )
},
)
@@ -142,14 +146,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -162,60 +167,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -231,9 +252,11 @@ class DataTrainingArguments:
forced_bos_token: Optional[str] = field(
default=None,
metadata={
- "help": "The token to force as the first generated token after the decoder_start_token_id."
- "Useful for multilingual models like mBART where the first generated token"
- "needs to be the target language token (Usually it is the target language token)"
+ "help": (
+ "The token to force as the first generated token after the decoder_start_token_id."
+ "Useful for multilingual models like mBART where the first generated token"
+ "needs to be the target language token (Usually it is the target language token)"
+ )
},
)
@@ -279,6 +302,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_summarization", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -410,17 +437,18 @@ def main():
):
if model_args.resize_position_embeddings is None:
logger.warning(
- f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} "
- f"to {data_args.max_source_length}."
+ "Increasing the model's number of position embedding vectors from"
+ f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
)
model.resize_position_embeddings(data_args.max_source_length)
elif model_args.resize_position_embeddings:
model.resize_position_embeddings(data_args.max_source_length)
else:
raise ValueError(
- f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
- f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
- "resize the model's position encodings by passing `--resize_position_embeddings`."
+ f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
+ f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
+ f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
+ " model's position encodings by passing `--resize_position_embeddings`."
)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py
index e08edbf51301ac..98c7f09bd4f01f 100644
--- a/examples/pytorch/summarization/run_summarization_no_trainer.py
+++ b/examples/pytorch/summarization/run_summarization_no_trainer.py
@@ -36,13 +36,13 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from filelock import FileLock
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
@@ -50,11 +50,11 @@
SchedulerType,
get_scheduler,
)
-from transformers.utils import get_full_repo_name, is_offline_mode
+from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
# You should update this to your particular problem to have better documentation of `model_type`
@@ -110,20 +110,22 @@ def parse_args():
"--ignore_pad_token_for_loss",
type=bool,
default=True,
- help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
+ help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
)
parser.add_argument(
"--max_source_length",
type=int,
default=1024,
- help="The maximum total input sequence length after "
- "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after "
+ "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--source_prefix",
type=str,
default=None,
- help="A prefix to add before every source text " "(useful for T5 models).",
+ help="A prefix to add before every source text (useful for T5 models).",
)
parser.add_argument(
"--preprocessing_num_workers",
@@ -138,18 +140,22 @@ def parse_args():
"--max_target_length",
type=int,
default=128,
- help="The maximum total sequence length for target text after "
- "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
- "during ``evaluate`` and ``predict``.",
+ help=(
+ "The maximum total sequence length for target text after "
+ "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
+ "during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--val_max_target_length",
type=int,
default=None,
- help="The maximum total sequence length for validation "
- "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
- "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
- "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "The maximum total sequence length for validation "
+ "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
+ "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
+ "param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--max_length",
@@ -164,8 +170,10 @@ def parse_args():
"--num_beams",
type=int,
default=None,
- help="Number of beams to use for evaluation. This argument will be "
- "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "Number of beams to use for evaluation. This argument will be "
+ "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--pad_to_max_length",
@@ -176,7 +184,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -278,7 +286,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -301,7 +319,16 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_summarization_no_trainer", args)
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
if args.source_prefix is None and args.model_name_or_path in [
"t5-small",
"t5-base",
@@ -313,20 +340,13 @@ def main():
"You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
"`--source_prefix 'summarize: ' `"
)
- # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -517,7 +537,7 @@ def postprocess_text(preds, labels):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -538,6 +558,10 @@ def postprocess_text(preds, labels):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -546,12 +570,15 @@ def postprocess_text(preds, labels):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("summarization_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("summarization_no_trainer", experiment_config)
# Metric
metric = load_metric("rouge")
@@ -666,11 +693,11 @@ def postprocess_text(preds, labels):
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
else:
- samples_seen += decoded_labels.shape[0]
+ samples_seen += len(decoded_labels)
metric.add_batch(
predictions=decoded_preds,
@@ -685,10 +712,10 @@ def postprocess_text(preds, labels):
logger.info(result)
if args.with_tracking:
- result["train_loss"] = total_loss
+ result["train_loss"] = total_loss.item() / len(train_dataloader)
result["epoch"] = epoch
result["step"] = completed_steps
- accelerator.log(result)
+ accelerator.log(result, step=completed_steps)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
diff --git a/examples/pytorch/test_accelerate_examples.py b/examples/pytorch/test_accelerate_examples.py
index 14eef9c7f77228..6e17826727e472 100644
--- a/examples/pytorch/test_accelerate_examples.py
+++ b/examples/pytorch/test_accelerate_examples.py
@@ -18,49 +18,18 @@
import json
import logging
import os
+import shutil
+import subprocess
import sys
-from unittest.mock import patch
+import tempfile
import torch
+from accelerate.utils import write_basic_config
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
from transformers.utils import is_apex_available
-SRC_DIRS = [
- os.path.join(os.path.dirname(__file__), dirname)
- for dirname in [
- "text-generation",
- "text-classification",
- "token-classification",
- "language-modeling",
- "multiple-choice",
- "question-answering",
- "summarization",
- "translation",
- "image-classification",
- "speech-recognition",
- "audio-classification",
- "speech-pretraining",
- "image-pretraining",
- "semantic-segmentation",
- ]
-]
-sys.path.extend(SRC_DIRS)
-
-
-if SRC_DIRS is not None:
- import run_clm_no_trainer
- import run_glue_no_trainer
- import run_image_classification_no_trainer
- import run_mlm_no_trainer
- import run_ner_no_trainer
- import run_qa_no_trainer as run_squad_no_trainer
- import run_semantic_segmentation_no_trainer
- import run_summarization_no_trainer
- import run_swag_no_trainer
- import run_translation_no_trainer
-
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
@@ -94,10 +63,22 @@ def is_cuda_and_apex_available():
class ExamplesTestsNoTrainer(TestCasePlus):
+ @classmethod
+ def setUpClass(cls):
+ # Write Accelerate config, will pick up on CPU, GPU, and multi-GPU
+ cls.tmpdir = tempfile.mkdtemp()
+ cls.configPath = os.path.join(cls.tmpdir, "default_config.yml")
+ write_basic_config(save_location=cls.configPath)
+ cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tmpdir)
+
def test_run_glue_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_glue_no_trainer.py
+ {self.examples_dir}/pytorch/text-classification/run_glue_no_trainer.py
--model_name_or_path distilbert-base-uncased
--output_dir {tmp_dir}
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
@@ -113,17 +94,16 @@ def test_run_glue_no_trainer(self):
if is_cuda_and_apex_available():
testargs.append("--fp16")
- with patch.object(sys, "argv", testargs):
- run_glue_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_accuracy"], 0.75)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "glue_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "glue_no_trainer")))
def test_run_clm_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_clm_no_trainer.py
+ {self.examples_dir}/pytorch/language-modeling/run_clm_no_trainer.py
--model_name_or_path distilgpt2
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
@@ -140,17 +120,16 @@ def test_run_clm_no_trainer(self):
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
- with patch.object(sys, "argv", testargs):
- run_clm_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertLess(result["perplexity"], 100)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "clm_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertLess(result["perplexity"], 100)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "clm_no_trainer")))
def test_run_mlm_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_mlm_no_trainer.py
+ {self.examples_dir}/pytorch/language-modeling/run_mlm_no_trainer.py
--model_name_or_path distilroberta-base
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
@@ -160,12 +139,11 @@ def test_run_mlm_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_mlm_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertLess(result["perplexity"], 42)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertLess(result["perplexity"], 42)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))
def test_run_ner_no_trainer(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
@@ -173,7 +151,7 @@ def test_run_ner_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_ner_no_trainer.py
+ {self.examples_dir}/pytorch/token-classification/run_ner_no_trainer.py
--model_name_or_path bert-base-uncased
--train_file tests/fixtures/tests_samples/conll/sample.json
--validation_file tests/fixtures/tests_samples/conll/sample.json
@@ -187,18 +165,17 @@ def test_run_ner_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_ner_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_accuracy"], 0.75)
- self.assertLess(result["train_loss"], 0.5)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+ self.assertLess(result["train_loss"], 0.5)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
def test_run_squad_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_qa_no_trainer.py
+ {self.examples_dir}/pytorch/question-answering/run_qa_no_trainer.py
--model_name_or_path bert-base-uncased
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
@@ -213,19 +190,18 @@ def test_run_squad_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_squad_no_trainer.main()
- result = get_results(tmp_dir)
- # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
- self.assertGreaterEqual(result["eval_f1"], 30)
- self.assertGreaterEqual(result["eval_exact"], 30)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
+ self.assertGreaterEqual(result["eval_f1"], 28)
+ self.assertGreaterEqual(result["eval_exact"], 28)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))
def test_run_swag_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_swag_no_trainer.py
+ {self.examples_dir}/pytorch/multiple-choice/run_swag_no_trainer.py
--model_name_or_path bert-base-uncased
--train_file tests/fixtures/tests_samples/swag/sample.json
--validation_file tests/fixtures/tests_samples/swag/sample.json
@@ -238,17 +214,16 @@ def test_run_swag_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_swag_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_accuracy"], 0.8)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_accuracy"], 0.8)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))
@slow
def test_run_summarization_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_summarization_no_trainer.py
+ {self.examples_dir}/pytorch/summarization/run_summarization_no_trainer.py
--model_name_or_path t5-small
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
@@ -262,21 +237,20 @@ def test_run_summarization_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_summarization_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_rouge1"], 10)
- self.assertGreaterEqual(result["eval_rouge2"], 2)
- self.assertGreaterEqual(result["eval_rougeL"], 7)
- self.assertGreaterEqual(result["eval_rougeLsum"], 7)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "summarization_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_rouge1"], 10)
+ self.assertGreaterEqual(result["eval_rouge2"], 2)
+ self.assertGreaterEqual(result["eval_rougeL"], 7)
+ self.assertGreaterEqual(result["eval_rougeLsum"], 7)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "summarization_no_trainer")))
@slow
def test_run_translation_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_translation_no_trainer.py
+ {self.examples_dir}/pytorch/translation/run_translation_no_trainer.py
--model_name_or_path sshleifer/student_marian_en_ro_6_1
--source_lang en
--target_lang ro
@@ -294,12 +268,11 @@ def test_run_translation_no_trainer(self):
--with_tracking
""".split()
- with patch.object(sys, "argv", testargs):
- run_translation_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_bleu"], 30)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "translation_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_bleu"], 30)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "translation_no_trainer")))
@slow
def test_run_semantic_segmentation_no_trainer(self):
@@ -308,7 +281,7 @@ def test_run_semantic_segmentation_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_semantic_segmentation_no_trainer.py
+ {self.examples_dir}/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
--dataset_name huggingface/semantic-segmentation-test-sample
--output_dir {tmp_dir}
--max_train_steps=10
@@ -319,15 +292,14 @@ def test_run_semantic_segmentation_no_trainer(self):
--checkpointing_steps epoch
""".split()
- with patch.object(sys, "argv", testargs):
- run_semantic_segmentation_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
def test_run_image_classification_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
- run_image_classification_no_trainer.py
+ {self.examples_dir}/pytorch/image-classification/run_image_classification_no_trainer.py
--dataset_name huggingface/image-classification-test-sample
--output_dir {tmp_dir}
--num_warmup_steps=8
@@ -339,9 +311,8 @@ def test_run_image_classification_no_trainer(self):
--seed 42
""".split()
- with patch.object(sys, "argv", testargs):
- run_image_classification_no_trainer.main()
- result = get_results(tmp_dir)
- self.assertGreaterEqual(result["eval_accuracy"], 0.50)
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
- self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
+ _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+ result = get_results(tmp_dir)
+ self.assertGreaterEqual(result["eval_accuracy"], 0.50)
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
diff --git a/examples/pytorch/text-classification/README.md b/examples/pytorch/text-classification/README.md
index 5f853149e346fe..391aaf4d3f038f 100644
--- a/examples/pytorch/text-classification/README.md
+++ b/examples/pytorch/text-classification/README.md
@@ -22,7 +22,7 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
-and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
+and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
(the script might need some tweaks in that case, refer to the comments inside for help).
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
@@ -79,6 +79,8 @@ python run_glue.py \
--output_dir /tmp/imdb/
```
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
### Mixed precision training
@@ -115,7 +117,7 @@ the mean of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index b15a0378ca7d2f..22f5497399aa0d 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -42,12 +42,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -89,8 +89,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -99,29 +101,37 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
train_file: Optional[str] = field(
@@ -180,10 +190,16 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
def main():
@@ -199,6 +215,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_glue", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -352,6 +372,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Preprocessing the raw_datasets
diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py
index 73f52825a3b600..4e73a10e9a3302 100644
--- a/examples/pytorch/text-classification/run_glue_no_trainer.py
+++ b/examples/pytorch/text-classification/run_glue_no_trainer.py
@@ -22,16 +22,17 @@
from pathlib import Path
import datasets
+import torch
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
- AdamW,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
@@ -41,11 +42,11 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -166,7 +167,22 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
+ )
+ parser.add_argument(
+ "--ignore_mismatched_sizes",
+ action="store_true",
+ help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
@@ -189,21 +205,23 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_glue_no_trainer", args)
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -255,7 +273,7 @@ def main():
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file
- extension = (args.train_file if args.train_file is not None else args.valid_file).split(".")[-1]
+ extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
raw_datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -290,6 +308,7 @@ def main():
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
+ ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
# Preprocessing the datasets
@@ -327,7 +346,7 @@ def main():
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
"\nIgnoring the model labels as a result.",
)
- elif args.task_name is None:
+ elif args.task_name is None and not is_regression:
label_to_id = {v: i for i, v in enumerate(label_list)}
if label_to_id is not None:
@@ -399,7 +418,7 @@ def preprocess_function(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -420,6 +439,10 @@ def preprocess_function(examples):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -428,12 +451,15 @@ def preprocess_function(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("glue_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("glue_no_trainer", experiment_config)
# Get the metric function
if args.task_name is not None:
@@ -514,12 +540,13 @@ def preprocess_function(examples):
model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
- outputs = model(**batch)
+ with torch.no_grad():
+ outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
predictions, references = accelerator.gather((predictions, batch["labels"]))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
@@ -536,10 +563,11 @@ def preprocess_function(examples):
accelerator.log(
{
"accuracy" if args.task_name is not None else "glue": eval_metric,
- "train_loss": total_loss,
+ "train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/text-classification/run_xnli.py b/examples/pytorch/text-classification/run_xnli.py
index cd4d44b6a61e0a..d0a449c3521c3a 100755
--- a/examples/pytorch/text-classification/run_xnli.py
+++ b/examples/pytorch/text-classification/run_xnli.py
@@ -42,12 +42,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -67,8 +67,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -77,33 +79,39 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
- server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
- server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
@dataclass
@@ -146,10 +154,16 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
def main():
@@ -160,14 +174,9 @@ def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
- # Setup distant debugging if needed
- if data_args.server_ip and data_args.server_port:
- # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
- import ptvsd
-
- print("Waiting for debugger attach")
- ptvsd.enable_attach(address=(data_args.server_ip, data_args.server_port), redirect_output=True)
- ptvsd.wait_for_attach()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_xnli", model_args)
# Setup logging
logging.basicConfig(
@@ -279,6 +288,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Preprocessing the datasets
diff --git a/examples/pytorch/token-classification/README.md b/examples/pytorch/token-classification/README.md
index 01f586dff2fea5..496722cf6b9a14 100644
--- a/examples/pytorch/token-classification/README.md
+++ b/examples/pytorch/token-classification/README.md
@@ -55,6 +55,8 @@ uses special features of those tokenizers. You can check if your favorite model
[this table](https://huggingface.co/transformers/index.html#supported-frameworks), if it doesn't you can still use the old version
of the script.
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
## Old version of the script
You can find the old version of the PyTorch script [here](https://github.com/huggingface/transformers/blob/main/examples/legacy/token-classification/run_ner.py).
@@ -73,7 +75,7 @@ the mean of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py
index fc54c77d620477..bffc4395fd21da 100755
--- a/examples/pytorch/token-classification/run_ner.py
+++ b/examples/pytorch/token-classification/run_ner.py
@@ -43,12 +43,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
@@ -81,10 +81,16 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
+ ignore_mismatched_sizes: bool = field(
+ default=False,
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+ )
@dataclass
@@ -127,44 +133,56 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If set, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
label_all_tokens: bool = field(
default=False,
metadata={
- "help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
- "one (in which case the other tokens will have a padding index)."
+ "help": (
+ "Whether to put the label for one word on all tokens of generated by that word or just on the "
+ "one (in which case the other tokens will have a padding index)."
+ )
},
)
return_entity_level_metrics: bool = field(
@@ -198,6 +216,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_ner", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -350,14 +372,15 @@ def get_label_list(labels):
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# Model has labels -> use them.
@@ -373,8 +396,8 @@ def get_label_list(labels):
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
- f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
- "\nIgnoring the model labels as a result.",
+ f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
+ f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
)
# Set the correspondences label/ID inside the model config
diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py
index 6281ee162d261d..4910b30e04d608 100755
--- a/examples/pytorch/token-classification/run_ner_no_trainer.py
+++ b/examples/pytorch/token-classification/run_ner_no_trainer.py
@@ -34,12 +34,12 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForTokenClassification,
AutoTokenizer,
@@ -49,11 +49,11 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
# You should update this to your particular problem to have better documentation of `model_type`
@@ -113,7 +113,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -220,7 +220,22 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
+ )
+ parser.add_argument(
+ "--ignore_mismatched_sizes",
+ action="store_true",
+ help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
@@ -244,20 +259,23 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_ner_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -386,6 +404,7 @@ def get_label_list(labels):
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
+ ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
else:
logger.info("Training new model from scratch")
@@ -406,8 +425,8 @@ def get_label_list(labels):
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
- f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
- "\nIgnoring the model labels as a result.",
+ f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels:"
+ f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
)
# Set the correspondences label/ID inside the model config
@@ -510,7 +529,7 @@ def tokenize_and_align_labels(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Use the device given by the `accelerator` object.
device = accelerator.device
@@ -535,6 +554,10 @@ def tokenize_and_align_labels(examples):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -543,12 +566,15 @@ def tokenize_and_align_labels(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("ner_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("ner_no_trainer", experiment_config)
# Metrics
metric = load_metric("seqeval")
@@ -676,7 +702,7 @@ def compute_metrics():
predictions_gathered, labels_gathered = accelerator.gather((predictions, labels))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
predictions_gathered = predictions_gathered[: len(eval_dataloader.dataset) - samples_seen]
labels_gathered = labels_gathered[: len(eval_dataloader.dataset) - samples_seen]
else:
@@ -691,7 +717,13 @@ def compute_metrics():
accelerator.print(f"epoch {epoch}:", eval_metric)
if args.with_tracking:
accelerator.log(
- {"seqeval": eval_metric, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "seqeval": eval_metric,
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
@@ -724,7 +756,9 @@ def compute_metrics():
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
- json.dump({"eval_accuracy": eval_metric["accuracy"], "train_loss": float(loss.cpu().detach().numpy())}, f)
+ json.dump(
+ {"eval_accuracy": eval_metric["accuracy"], "train_loss": total_loss.item() / len(train_dataloader)}, f
+ )
if __name__ == "__main__":
diff --git a/examples/pytorch/translation/README.md b/examples/pytorch/translation/README.md
index 00c03a9be139fc..4bd66ea0acd130 100644
--- a/examples/pytorch/translation/README.md
+++ b/examples/pytorch/translation/README.md
@@ -162,7 +162,7 @@ the mean of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then
diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py
index 6f2630104f7ea5..1cd55a6f4a2dcd 100755
--- a/examples/pytorch/translation/run_translation.py
+++ b/examples/pytorch/translation/run_translation.py
@@ -46,12 +46,12 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
@@ -91,8 +91,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -116,15 +118,12 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
- "a jsonlines file."
+ "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
- metadata={
- "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
- },
+ metadata={"help": "An optional input test data file to evaluate the metrics (sacreblue) on a jsonlines file."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
@@ -136,60 +135,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -204,9 +219,11 @@ class DataTrainingArguments:
forced_bos_token: Optional[str] = field(
default=None,
metadata={
- "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
- "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
- "needs to be the target language token.(Usually it is the target language token)"
+ "help": (
+ "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for"
+ " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
+ " be the target language token.(Usually it is the target language token)"
+ )
},
)
@@ -243,6 +260,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_translation", model_args, data_args)
+
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/pytorch/translation/run_translation_no_trainer.py b/examples/pytorch/translation/run_translation_no_trainer.py
index f01267a288511e..acc49ffdfcd218 100644
--- a/examples/pytorch/translation/run_translation_no_trainer.py
+++ b/examples/pytorch/translation/run_translation_no_trainer.py
@@ -35,12 +35,12 @@
import transformers
from accelerate import Accelerator
+from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
- AdamW,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
@@ -51,11 +51,11 @@
default_data_collator,
get_scheduler,
)
-from transformers.utils import get_full_repo_name
+from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
# You should update this to your particular problem to have better documentation of `model_type`
@@ -94,41 +94,51 @@ def parse_args():
"--num_beams",
type=int,
default=None,
- help="Number of beams to use for evaluation. This argument will be "
- "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "Number of beams to use for evaluation. This argument will be "
+ "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--max_source_length",
type=int,
default=1024,
- help="The maximum total input sequence length after "
- "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after "
+ "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--max_target_length",
type=int,
default=128,
- help="The maximum total sequence length for target text after "
- "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
- "during ``evaluate`` and ``predict``.",
+ help=(
+ "The maximum total sequence length for target text after "
+ "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
+ "during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--val_max_target_length",
type=int,
default=None,
- help="The maximum total sequence length for validation "
- "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
- "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
- "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "The maximum total sequence length for validation "
+ "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
+ "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
+ "param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--pad_to_max_length",
type=bool,
default=False,
- help="Whether to pad all samples to model maximum sentence "
- "length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More"
- "efficient on GPU but very bad for TPU.",
+ help=(
+ "Whether to pad all samples to model maximum sentence "
+ "length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More"
+ "efficient on GPU but very bad for TPU."
+ ),
)
parser.add_argument(
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
@@ -137,7 +147,7 @@ def parse_args():
"--ignore_pad_token_for_loss",
type=bool,
default=True,
- help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
+ help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
)
parser.add_argument("--source_lang", type=str, default=None, help="Source language id for translation.")
parser.add_argument("--target_lang", type=str, default=None, help="Target language id for translation.")
@@ -145,7 +155,7 @@ def parse_args():
"--source_prefix",
type=str,
default=None,
- help="A prefix to add before every source text " "(useful for T5 models).",
+ help="A prefix to add before every source text (useful for T5 models).",
)
parser.add_argument(
"--preprocessing_num_workers",
@@ -169,7 +179,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
+ required=False,
)
parser.add_argument(
"--config_name",
@@ -259,7 +269,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
- help="Whether to load in all available experiment trackers from the environment and use them for logging.",
+ help="Whether to enable experiment trackers for logging.",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="all",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+ "Only applicable when `--with_tracking` is passed."
+ ),
)
args = parser.parse_args()
@@ -285,9 +305,16 @@ def main():
# Parse the arguments
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_translation_no_trainer", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
- # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
- accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
+ # in the environment
+ accelerator = (
+ Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
+ )
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@@ -295,11 +322,7 @@ def main():
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
- logger.info(accelerator.state)
-
- # Setup logging, we only want one process per machine to log things on the screen.
- # accelerator.is_local_main_process is only True for one process per machine.
- logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
@@ -495,7 +518,7 @@ def preprocess_function(examples):
"weight_decay": 0.0,
},
]
- optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -516,6 +539,10 @@ def preprocess_function(examples):
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
# Figure out how many steps we should save the Accelerator states
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
@@ -524,12 +551,15 @@ def preprocess_function(examples):
else:
checkpointing_steps = None
- # We need to initialize the trackers we use, and also store our configuration
+ # We need to initialize the trackers we use, and also store our configuration.
+ # We initialize the trackers only on main process because `accelerator.log`
+ # only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
- experiment_config = vars(args)
- # TensorBoard cannot log Enums, need the raw value
- experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
- accelerator.init_trackers("translation_no_trainer", experiment_config)
+ if accelerator.is_main_process:
+ experiment_config = vars(args)
+ # TensorBoard cannot log Enums, need the raw value
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+ accelerator.init_trackers("translation_no_trainer", experiment_config)
metric = load_metric("sacrebleu")
@@ -650,11 +680,11 @@ def postprocess_text(preds, labels):
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
- if step == len(eval_dataloader):
+ if step == len(eval_dataloader) - 1:
decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
else:
- samples_seen += decoded_labels.shape[0]
+ samples_seen += len(decoded_labels)
metric.add_batch(predictions=decoded_preds, references=decoded_labels)
eval_metric = metric.compute()
@@ -662,7 +692,13 @@ def postprocess_text(preds, labels):
if args.with_tracking:
accelerator.log(
- {"blue": eval_metric["score"], "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
+ {
+ "blue": eval_metric["score"],
+ "train_loss": total_loss.item() / len(train_dataloader),
+ "epoch": epoch,
+ "step": completed_steps,
+ },
+ step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
diff --git a/examples/pytorch/xla_spawn.py b/examples/pytorch/xla_spawn.py
index d84b41994564a8..5df6bfa2d5dc31 100644
--- a/examples/pytorch/xla_spawn.py
+++ b/examples/pytorch/xla_spawn.py
@@ -39,9 +39,7 @@ def parse_args():
"""
parser = ArgumentParser(
description=(
- "PyTorch TPU distributed training launch "
- "helper utility that will spawn up "
- "multiple distributed processes"
+ "PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
)
)
diff --git a/examples/research_projects/adversarial/run_hans.py b/examples/research_projects/adversarial/run_hans.py
index 31acbd3a8a6fd9..0576471fbc50a6 100644
--- a/examples/research_projects/adversarial/run_hans.py
+++ b/examples/research_projects/adversarial/run_hans.py
@@ -77,8 +77,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -110,7 +112,8 @@ def main():
and not training_args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
+ " --overwrite_output_dir to overcome."
)
# Setup logging
diff --git a/examples/research_projects/adversarial/utils_hans.py b/examples/research_projects/adversarial/utils_hans.py
index b02bf81352778b..e54792ad2f82b9 100644
--- a/examples/research_projects/adversarial/utils_hans.py
+++ b/examples/research_projects/adversarial/utils_hans.py
@@ -197,7 +197,7 @@ def __init__(
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
def gen():
- for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
+ for ex_index, ex in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
@@ -268,7 +268,7 @@ def get_labels(self):
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
- for (i, line) in enumerate(lines):
+ for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
@@ -303,7 +303,7 @@ def hans_convert_examples_to_features(
label_map = {label: i for i, label in enumerate(label_list)}
features = []
- for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
+ for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d" % (ex_index))
diff --git a/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_albert.py b/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_albert.py
index 006ff98c950f81..5e17352dc19b54 100644
--- a/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_albert.py
+++ b/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_albert.py
@@ -84,7 +84,10 @@ def reset_stats(self):
def log_stats(self):
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
- message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
+ message = (
+ f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up ="
+ f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
+ )
print(message)
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
diff --git a/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_bert.py b/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_bert.py
index ff5c2b51e8b359..b32f47d0c30020 100644
--- a/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_bert.py
+++ b/examples/research_projects/bert-loses-patience/pabee/modeling_pabee_bert.py
@@ -89,7 +89,10 @@ def reset_stats(self):
def log_stats(self):
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
- message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
+ message = (
+ f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up ="
+ f" {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
+ )
print(message)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
diff --git a/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py b/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py
index def4dff7766428..d4121655e8233d 100755
--- a/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py
+++ b/examples/research_projects/bert-loses-patience/run_glue_with_pabee.py
@@ -483,8 +483,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -574,8 +576,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument(
"--local_rank",
diff --git a/examples/research_projects/bertabs/run_summarization.py b/examples/research_projects/bertabs/run_summarization.py
index 33be67233ff6da..fcfae6b8c6c755 100644
--- a/examples/research_projects/bertabs/run_summarization.py
+++ b/examples/research_projects/bertabs/run_summarization.py
@@ -325,7 +325,8 @@ def main():
if not documents_dir_is_valid(args.documents_dir):
raise FileNotFoundError(
- "We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
+ "We could not find the directory you specified for the documents to summarize, or it was empty. Please"
+ " specify a valid path."
)
os.makedirs(args.summaries_output_dir, exist_ok=True)
diff --git a/examples/research_projects/bertology/run_bertology.py b/examples/research_projects/bertology/run_bertology.py
index 1018359dc62e0c..030573d87f3532 100644
--- a/examples/research_projects/bertology/run_bertology.py
+++ b/examples/research_projects/bertology/run_bertology.py
@@ -338,8 +338,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. \n"
- "Sequences longer than this will be truncated, sequences shorter padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. \n"
+ "Sequences longer than this will be truncated, sequences shorter padded."
+ ),
)
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
diff --git a/examples/research_projects/bertology/run_prune_gpt.py b/examples/research_projects/bertology/run_prune_gpt.py
index 49a867b96dd4ce..68cece6e997ad2 100644
--- a/examples/research_projects/bertology/run_prune_gpt.py
+++ b/examples/research_projects/bertology/run_prune_gpt.py
@@ -314,8 +314,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. \n"
- "Sequences longer than this will be truncated, sequences shorter padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. \n"
+ "Sequences longer than this will be truncated, sequences shorter padded."
+ ),
)
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
diff --git a/examples/research_projects/codeparrot/README.md b/examples/research_projects/codeparrot/README.md
index 2b51b3ba4b572f..761b77a6df977c 100644
--- a/examples/research_projects/codeparrot/README.md
+++ b/examples/research_projects/codeparrot/README.md
@@ -37,24 +37,39 @@ Additionally, sure you have git-lfs installed. You can find instructions for how
The source of the dataset is the GitHub dump available on Google's [BigQuery](https://cloud.google.com/blog/topics/public-datasets/github-on-bigquery-analyze-all-the-open-source-code). The database was queried for all Python files with less than 1MB in size resulting in a 180GB dataset with over 20M files. The dataset is available on the Hugging Face Hub [here](https://huggingface.co/datasets/transformersbook/codeparrot).
### Preprocessing
-The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374):
+The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374) and some new ones:
- exact deduplication using each file's hash
- filtering files with max line length > 1000
- filtering files with mean line length > 100
- fraction of alphanumeric characters < 0.25
- containing the word "auto-generated" or similar in the first 5 lines
+- filtering with a probability of 0.7 of files with a mention of "test file" or "configuration file" or similar in the first 5 lines
+- filtering with a probability of 0.7 of files with high occurence of the keywords "test " or "config"
+- filtering with a probability of 0.7 of files without a mention of the keywords `def` , `for`, `while` and `class`
+- filtering files that use the assignment operator `=` less than 5 times
+- filtering files with ratio between number of characters and number of tokens after tokenization < 1.5 (the average ratio is 3.6)
-The script to process the full dataset can be found in `scripts/preprocessing.py`. Executing the script on 16 vCPUs takes roughly 3h and removes 70% of the original dataset. The cleaned [train](https://huggingface.co/datasets/lvwerra/codeparrot-clean-train) and [validation](https://huggingface.co/datasets/lvwerra/codeparrot-clean-valid) splits are also available on the Hub if you want to skip this step or use the data for another project.
+The script to process the full dataset can be found in `scripts/preprocessing.py`. Executing the script on 16 vCPUs takes roughly 3h and removes 70% of the original dataset. The cleaned [train](https://huggingface.co/datasets/loubnabnl/codeparrot-clean-train-v2) and [validation](https://huggingface.co/datasets/loubnabnl/codeparrot-clean-valid-v2) splits are also available on the Hub if you want to skip this step or use the data for another project.
To execute the preprocessing run the following command:
```bash
python scripts/preprocessing.py \
---dataset_name lvwerra/codeparrot \
+--dataset_name transformersbook/codeparrot \
--output_dir codeparrot-clean
```
During preprocessing the dataset is downloaded and stored locally as well as caches of the computations. Make sure you have more than 500GB free disk space to execute it.
+### Pretokenization
+The tokenization of the data might be slow during the training especially for small models. We provide code to pretokenize the data beforehand in `scripts/pretokenizing.py`, but this step is optional. The dataset is downloaded and stored locally and the tokenized data is pushed to the hub. The tokenized clean [train](https://huggingface.co/datasets/loubnabnl/tokenized-codeparrot-train) and [validation](https://huggingface.co/datasets/loubnabnl/tokenized-codeparrot-valid) datasets are available if you want to use them directly.
+
+To execute the pretokenization, for the clean train data for instance, run the following command:
+```bash
+python scripts/pretokenizing.py \
+--dataset_name lvwerra/codeparrot-clean-train \
+--tokenized_data_repo tokenized-codeparrot-train
+```
+
## Tokenizer
Before training a new model for code we create a new tokenizer that is efficient at code tokenization. To train the tokenizer you can run the following command:
```bash
@@ -77,7 +92,8 @@ python scripts/initialize_model.py \
```
This will initialize a new model with the architecture and configuration of `gpt2-large` and use the tokenizer to appropriately size the input embeddings. Finally, the initilaized model is pushed the the hub.
-Now that the dataset, tokenizer, and model are ready we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively.
+We can either pass the name of a text dataset or a pretokenized dataset which speeds up training a bit.
+Now that the tokenizer and model are also ready we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively.
First you need to configure `accelerate` and login to Weights & Biases:
@@ -89,7 +105,7 @@ wandb login
Note that during the `accelerate` configuration we enabled FP16. Then to train the large model you can run
```bash
-python scripts/codeparrot_training.py
+accelerate launch scripts/codeparrot_training.py
```
If you want to train the small model you need to make some modifications:
@@ -149,7 +165,7 @@ python scripts/validation_loss.py \
In addition we evaluate the model on OpenAI's _HumanEval_ benchmark. You can run the evaluation with the following command:
```bash
-python scripts/human_eval.py --model_ckpt lvwerra/codeparrot \
+accelerate launch scripts/human_eval.py --model_ckpt lvwerra/codeparrot \
--do_sample True \
--temperature 0.2 \
--top_p 0.95 \
@@ -162,7 +178,7 @@ The results as well as reference values are shown in the following table:
| Model | pass@1 | pass@10 | pass@100|
|-------|--------|---------|---------|
|CodeParrot š¦ (110M) | 3.80% | 6.57% | 12.78% |
-|CodeParrot š¦ (1.5B) | 3.58% | 8.03% | 14.96% |
+|CodeParrot š¦ (1.5B) | 3.99% | 8.69% | 17.88% |
|||||
|Codex (25M)| 3.21% | 7.1% | 12.89%|
|Codex (85M)| 8.22% | 12.81% | 22.40% |
diff --git a/examples/research_projects/codeparrot/requirements.txt b/examples/research_projects/codeparrot/requirements.txt
index a8aadb4ed9734a..267bcb9cb047c5 100644
--- a/examples/research_projects/codeparrot/requirements.txt
+++ b/examples/research_projects/codeparrot/requirements.txt
@@ -1,7 +1,7 @@
-transformers==4.15.0
+transformers==4.19.0
datasets==1.16.0
-accelerate==0.6.2
wandb==0.12.0
tensorboard==2.6.0
-torch==1.9.0
-huggingface-hub==0.1.0
\ No newline at end of file
+torch==1.11.0
+huggingface-hub==0.1.0
+git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
\ No newline at end of file
diff --git a/examples/research_projects/codeparrot/scripts/arguments.py b/examples/research_projects/codeparrot/scripts/arguments.py
index a94cda2d2f1b41..03d578cbb86048 100644
--- a/examples/research_projects/codeparrot/scripts/arguments.py
+++ b/examples/research_projects/codeparrot/scripts/arguments.py
@@ -9,12 +9,10 @@ class TrainingArguments:
"""
model_ckpt: Optional[str] = field(
- default="lvwerra/codeparrot",
- metadata={"help": "Model name or path of model to be trained."},
+ default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be trained."}
)
save_dir: Optional[str] = field(
- default="./",
- metadata={"help": "Save dir where model repo is cloned and models updates are saved to."},
+ default="./", metadata={"help": "Save dir where model repo is cloned and models updates are saved to."}
)
dataset_name_train: Optional[str] = field(
default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path of training dataset."}
@@ -26,7 +24,7 @@ class TrainingArguments:
valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
shuffle_buffer: Optional[int] = field(
- default=1000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
+ default=10000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
)
learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate fo training."})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate."})
@@ -39,7 +37,7 @@ class TrainingArguments:
gradient_checkpointing: Optional[bool] = field(
default=True, metadata={"help": "Use gradient checkpointing to reduce memory footprint."}
)
- max_train_steps: Optional[int] = field(default=50_000, metadata={"help": "Maximum number of training steps."})
+ max_train_steps: Optional[int] = field(default=50000, metadata={"help": "Maximum number of training steps."})
max_eval_steps: Optional[int] = field(
default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
)
@@ -50,9 +48,9 @@ class TrainingArguments:
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
)
resume_from_checkpoint: Optional[str] = field(
- default=None,
- metadata={"help": "States path if the training should continue from a checkpoint folder."},
+ default=None, metadata={"help": "States path if the training should continue from a checkpoint folder."}
)
+ tokenized: Optional[bool] = field(default=False, metadata={"help": "If True the data is pretokenized."})
@dataclass
@@ -62,8 +60,7 @@ class EvaluationArguments:
"""
model_ckpt: Optional[str] = field(
- default="lvwerra/codeparrot",
- metadata={"help": "Model name or path of model to be evaluated."},
+ default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
)
dataset_name: Optional[str] = field(
default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
@@ -83,8 +80,7 @@ class HumanEvalArguments:
"""
model_ckpt: Optional[str] = field(
- default="lvwerra/codeparrot",
- metadata={"help": "Model name or path of model to be evaluated."},
+ default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
)
num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
num_tasks: Optional[int] = field(
@@ -112,7 +108,10 @@ class HumanEvalArguments:
device_int: Optional[int] = field(
default=-1,
metadata={
- "help": "Determine which device to run the `text-generation` Pipeline on. -1 is CPU and any zero or positive number corresponds to which GPU device id to run on."
+ "help": (
+ "Determine which device to run the `text-generation` Pipeline on. -1 is CPU and any zero or positive"
+ " number corresponds to which GPU device id to run on."
+ )
},
)
@@ -130,7 +129,7 @@ class PreprocessingArguments:
},
)
dataset_name: Optional[str] = field(
- default="codeparrot", metadata={"help": "Folder or name of dataset to process."}
+ default="transformersbook/codeparrot", metadata={"help": "Folder or name of dataset to process."}
)
output_dir: Optional[str] = field(
default="codeparrot-clean", metadata={"help": "Folder to save processed processed dataset."}
@@ -148,6 +147,16 @@ class PreprocessingArguments:
alpha_frac: Optional[float] = field(
default=0.25, metadata={"help": "Maximum fraction of non-alphanumeric characters, otherwise file is filtered."}
)
+ min_token_ratio: Optional[float] = field(
+ default=1.5, metadata={"help": "Minimum character token ratio for the file, otherwise file is filtered."}
+ )
+ filter_proba: Optional[float] = field(
+ default=0.7, metadata={"help": "Probability for filtering config, test and uncommon files."}
+ )
+ tokenizer: Optional[str] = field(
+ default="lvwerra/codeparrot",
+ metadata={"help": "Name or path to the tokenizer."},
+ )
@dataclass
@@ -157,14 +166,13 @@ class TokenizerTrainingArguments:
"""
base_tokenizer: Optional[str] = field(
- default="gpt2",
- metadata={"help": "Base tokenizer to build new tokenizer from."},
+ default="gpt2", metadata={"help": "Base tokenizer to build new tokenizer from."}
)
dataset_name: Optional[str] = field(
default="transformersbook/codeparrot-train", metadata={"help": "Dataset to train tokenizer on."}
)
text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
- vocab_size: Optional[int] = field(default=200000, metadata={"help": "Number of examples to train tokenizer on."})
+ vocab_size: Optional[int] = field(default=200_000, metadata={"help": "Number of examples to train tokenizer on."})
n_examples: Optional[int] = field(
default=32768, metadata={"help": "Number of examples to train the tokenizer on."}
)
@@ -172,6 +180,24 @@ class TokenizerTrainingArguments:
push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})
+@dataclass
+class PretokenizationArguments:
+ """
+ Configuration for data pretokenization.
+ """
+
+ tokenizer_dir: Optional[str] = field(
+ default="lvwerra/codeparrot", metadata={"help": "Name or path to the tokenizer."}
+ )
+ dataset_name: Optional[str] = field(
+ default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path to the dataset to pretokenize."}
+ )
+ tokenized_data_repo: Optional[str] = field(
+ default="tokenized-codeparrot-train", metadata={"help": "Repo name of the pretokenized data."}
+ )
+ num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
+
+
@dataclass
class InitializationArguments:
"""
@@ -179,8 +205,7 @@ class InitializationArguments:
"""
config_name: Optional[str] = field(
- default="gpt2-large",
- metadata={"help": "Configuration to use for model initialization."},
+ default="gpt2-large", metadata={"help": "Configuration to use for model initialization."}
)
tokenizer_name: Optional[str] = field(
default="lvwerra/codeparrot", metadata={"help": "Tokenizer attached to model."}
diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py
index b00afac7508f4d..b2af8767a217a6 100644
--- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py
+++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py
@@ -7,14 +7,16 @@
import datasets
import torch
from datasets import load_dataset
+from torch.optim import AdamW
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
import transformers
from accelerate import Accelerator, DistributedType
from arguments import TrainingArguments
from huggingface_hub import Repository
-from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
+from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
class ConstantLengthDataset(IterableDataset):
@@ -25,21 +27,36 @@ class ConstantLengthDataset(IterableDataset):
dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
seq_length (int): Length of token sequences to return.
- num_of_sequences: Number of token sequences to keep in buffer.
- chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
+ num_of_sequences (int): Number of token sequences to keep in buffer.
+ chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
+ tokenized (bool): If true we use a pretokenized dataset.
"""
def __init__(
- self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
+ self,
+ tokenizer,
+ dataset,
+ infinite=False,
+ seq_length=1024,
+ num_of_sequences=1024,
+ chars_per_token=3.6,
+ tokenized=False,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.bos_token_id
self.dataset = dataset
self.seq_length = seq_length
- self.input_characters = seq_length * chars_per_token * num_of_sequences
self.epoch = 0
self.infinite = infinite
self.current_size = 0
+ self.tokenized = tokenized
+
+ if self.tokenized:
+ self.max_buffer_size = seq_length * num_of_sequences
+ self.content_field = "input_ids"
+ else:
+ self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
+ self.content_field = "content"
def __iter__(self):
iterator = iter(self.dataset)
@@ -47,10 +64,10 @@ def __iter__(self):
while more_examples:
buffer, buffer_len = [], 0
while True:
- if buffer_len >= self.input_characters:
+ if buffer_len >= self.max_buffer_size:
break
try:
- buffer.append(next(iterator)["content"])
+ buffer.append(next(iterator)[self.content_field])
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
@@ -60,7 +77,10 @@ def __iter__(self):
else:
more_examples = False
break
- tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
+ if self.tokenized:
+ tokenized_inputs = buffer
+ else:
+ tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
@@ -70,6 +90,9 @@ def __iter__(self):
self.current_size += 1
yield torch.tensor(input_ids)
+ def shuffle(self, buffer_size=1000):
+ return ShufflerIterDataPipe(self, buffer_size=buffer_size)
+
def setup_logging(args):
project_name = args.model_ckpt.split("/")[-1]
@@ -102,14 +125,19 @@ def create_dataloaders(args):
train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
- train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
- valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
- train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
+ train_dataset = ConstantLengthDataset(
+ tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
+ )
+ valid_dataset = ConstantLengthDataset(
+ tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
+ )
+ train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer)
+ train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
return train_dataloader, eval_dataloader
-def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
+def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]):
params_with_wd, params_without_wd = [], []
for n, p in model.named_parameters():
if any(nd in n for nd in no_decay):
@@ -162,14 +190,14 @@ def evaluate(args):
return loss.item(), perplexity.item()
-# Accelerator
-accelerator = Accelerator(log_with=["wandb", "tensorboard"])
-acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
-
# Settings
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()
+# Accelerator
+accelerator = Accelerator(log_with=["wandb", "tensorboard"], logging_dir=f"{args.save_dir}/log")
+acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
+
args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)
@@ -234,13 +262,14 @@ def get_lr():
model.train()
completed_steps = 0
t_start = time.time()
+loss_tracking = 0
for step, batch in enumerate(train_dataloader, start=1):
if args.resume_from_checkpoint and step < resume_step:
continue # we need to skip steps until we reach the resumed step
loss = model(batch, labels=batch, use_cache=False).loss
- log_metrics(
- step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
- )
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ loss_tracking += avg_loss.item() / args.gradient_accumulation_steps
+ log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()})
loss = loss / args.gradient_accumulation_steps
if step % args.gradient_accumulation_steps != 0:
# Prevent backward from doing gradient all_reduce in every step
@@ -250,16 +279,27 @@ def get_lr():
else:
accelerator.backward(loss)
else:
+ lr = get_lr()
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
- completed_steps += 1
elapsed_time = time.time() - t_start
tflops = compute_tflops(elapsed_time, accelerator, args)
- log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
+ log_metrics(
+ step,
+ {
+ "steps": completed_steps,
+ "loss/train": loss_tracking,
+ "lr": lr,
+ "tflops": tflops,
+ "time_per_iteration": elapsed_time,
+ },
+ )
t_start = time.time()
+ loss_tracking = 0
+ completed_steps += 1
if step % args.save_checkpoint_steps == 0:
logger.info("Evaluating and saving model checkpoint")
eval_loss, perplexity = evaluate(args)
diff --git a/examples/research_projects/codeparrot/scripts/human_eval.py b/examples/research_projects/codeparrot/scripts/human_eval.py
index 1eb5555cd79c4c..d0614134ad4732 100644
--- a/examples/research_projects/codeparrot/scripts/human_eval.py
+++ b/examples/research_projects/codeparrot/scripts/human_eval.py
@@ -186,7 +186,8 @@ def main():
_ = code_eval_metric.compute(references=[""], predictions=[[""]])
except ValueError as exception:
print(
- 'Code evaluation not enabled. Read the warning below carefully and then use `--HF_ALLOW_CODE_EVAL="1"` flag to enable code evaluation.'
+ 'Code evaluation not enabled. Read the warning below carefully and then use `--HF_ALLOW_CODE_EVAL="1"`'
+ " flag to enable code evaluation."
)
raise exception
diff --git a/examples/research_projects/codeparrot/scripts/initialize_model.py b/examples/research_projects/codeparrot/scripts/initialize_model.py
index 8654ccc9062267..9d066b19087396 100644
--- a/examples/research_projects/codeparrot/scripts/initialize_model.py
+++ b/examples/research_projects/codeparrot/scripts/initialize_model.py
@@ -10,13 +10,17 @@
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
# Config: "scale_attn_by_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
-config_kwargs = {"vocab_size": len(tokenizer), "scale_attn_by_layer_idx": True, "reorder_and_upcast_attn": True}
+config_kwargs = {
+ "vocab_size": len(tokenizer),
+ "scale_attn_by_inverse_layer_idx": True,
+ "reorder_and_upcast_attn": True,
+}
# Load model config (GPT-2 large in this case)
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
# Initialize new model with config
-model = AutoModelForCausalLM(config)
+model = AutoModelForCausalLM.from_config(config)
# Save model to the hub
model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
diff --git a/examples/research_projects/codeparrot/scripts/preprocessing.py b/examples/research_projects/codeparrot/scripts/preprocessing.py
index bb037750a60cad..0e5899f5de9a16 100644
--- a/examples/research_projects/codeparrot/scripts/preprocessing.py
+++ b/examples/research_projects/codeparrot/scripts/preprocessing.py
@@ -1,4 +1,5 @@
import gzip
+import hashlib
import multiprocessing
import os
import shutil
@@ -8,12 +9,12 @@
from datasets import load_dataset
from arguments import PreprocessingArguments
-from transformers import HfArgumentParser
+from transformers import AutoTokenizer, HfArgumentParser
def get_hash(example):
"""Get hash of content field."""
- return {"hash": hash(example["content"])}
+ return {"hash": hashlib.md5(example["content"].strip().encode("utf-8")).hexdigest()}
def line_stats(example):
@@ -49,18 +50,77 @@ def is_autogenerated(example, scan_width=5):
return {"autogenerated": False}
+def is_config_or_test(example, scan_width=5, coeff=0.05):
+ """Check if file is a configuration file or a unit test by :
+ 1- looking for keywords in the first few lines of the file.
+ 2- counting number of occurence of the words 'config' and 'test' with respect to number of lines.
+ """
+
+ keywords = ["unit tests", "test file", "configuration file"]
+ lines = example["content"].splitlines()
+ count_config = 0
+ count_test = 0
+ # first test
+ for _, line in zip(range(scan_width), lines):
+ for keyword in keywords:
+ if keyword in line.lower():
+ return {"config_or_test": True}
+ # second test
+ nlines = example["content"].count("\n")
+ threshold = int(coeff * nlines)
+ for line in lines:
+ count_config += line.lower().count("config")
+ count_test += line.lower().count("test")
+ if count_config > threshold or count_test > threshold:
+ return {"config_or_test": True}
+ return {"config_or_test": False}
+
+
+def has_no_keywords(example):
+ """Check if a python file has none of the keywords for: funcion, class, for loop, while loop."""
+ keywords = ["def ", "class ", "for ", "while "]
+ lines = example["content"].splitlines()
+ for line in lines:
+ for keyword in keywords:
+ if keyword in line.lower():
+ return {"has_no_keywords": False}
+ return {"has_no_keywords": True}
+
+
+def has_few_assignments(example, minimum=4):
+ """Check if file uses symbol '=' less than `minimum` times."""
+ lines = example["content"].splitlines()
+ counter = 0
+ for line in lines:
+ counter += line.lower().count("=")
+ if counter > minimum:
+ return {"has_few_assignments": False}
+ return {"has_few_assignments": True}
+
+
+def char_token_ratio(example):
+ """Compute character/token ratio of the file with tokenizer."""
+ input_ids = tokenizer(example["content"], truncation=False)["input_ids"]
+ ratio = len(example["content"]) / len(input_ids)
+ return {"ratio": ratio}
+
+
def preprocess(example):
"""Chain all preprocessing steps into one function to not fill cache."""
results = dict()
results.update(get_hash(example))
results.update(line_stats(example))
results.update(alpha_stats(example))
+ results.update(char_token_ratio(example))
results.update(is_autogenerated(example))
+ results.update(is_config_or_test(example))
+ results.update(has_no_keywords(example))
+ results.update(has_few_assignments(example))
return results
def filter(example, uniques, args):
- """Filter dataset with heuristics."""
+ """Filter dataset with heuristics. Config, test and has_no_keywords files are removed with a given probability."""
if not check_uniques(example, uniques):
return False
elif example["autogenerated"]:
@@ -71,6 +131,14 @@ def filter(example, uniques, args):
return False
elif example["alpha_frac"] < args.alpha_frac:
return False
+ elif example["ratio"] < args.min_token_ratio:
+ return False
+ elif example["config_or_test"] and np.random.rand() <= args.filter_proba:
+ return False
+ elif example["has_no_keywords"] and np.random.rand() <= args.filter_proba:
+ return False
+ elif example["has_few_assignments"]:
+ return False
else:
return True
@@ -88,6 +156,7 @@ def compress_file(file_path):
args = parser.parse_args()
if args.num_workers is None:
args.num_workers = multiprocessing.cpu_count()
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
# Load dataset
t_start = time.time()
diff --git a/examples/research_projects/codeparrot/scripts/pretokenizing.py b/examples/research_projects/codeparrot/scripts/pretokenizing.py
new file mode 100644
index 00000000000000..9ebe1e577ddefa
--- /dev/null
+++ b/examples/research_projects/codeparrot/scripts/pretokenizing.py
@@ -0,0 +1,49 @@
+import multiprocessing
+import time
+
+from datasets import load_dataset
+
+from arguments import PretokenizationArguments
+from transformers import AutoTokenizer, HfArgumentParser
+
+
+def tokenize(example):
+ output = dict()
+ output["input_ids"] = tokenizer(example["content"], truncation=False)["input_ids"]
+ output["ratio_char_token"] = len(example["content"]) / len(output["input_ids"])
+ return output
+
+
+parser = HfArgumentParser(PretokenizationArguments)
+args = parser.parse_args()
+if args.num_workers is None:
+ args.num_workers = multiprocessing.cpu_count()
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+
+t_start = time.time()
+ds = load_dataset(args.dataset_name, split="train")
+print(f"Dataset loaded in {time.time()-t_start:.2f}s")
+
+t_start = time.time()
+ds = ds.map(
+ tokenize,
+ num_proc=args.num_workers,
+ remove_columns=[
+ "repo_name",
+ "path",
+ "copies",
+ "size",
+ "content",
+ "license",
+ "hash",
+ "line_mean",
+ "line_max",
+ "alpha_frac",
+ "autogenerated",
+ ],
+)
+print(f"Dataset tokenized in {time.time()-t_start:.2f}s")
+
+t_start = time.time()
+ds.push_to_hub(args.tokenized_data_repo)
+print(f"Data pushed to the hub in {time.time()-t_start:.2f}s")
diff --git a/examples/research_projects/decision_transformer/requirements.txt b/examples/research_projects/decision_transformer/requirements.txt
index 4924f4b513d2ba..bf3dd4f1777f7c 100644
--- a/examples/research_projects/decision_transformer/requirements.txt
+++ b/examples/research_projects/decision_transformer/requirements.txt
@@ -33,7 +33,7 @@ cmaes==0.8.2
cmd2==2.4.0
codecarbon==1.2.0
colorlog==6.6.0
-cookiecutter==1.7.2
+cookiecutter==2.1.1
cryptography==36.0.2
csvw==2.0.0
cycler==0.11.0
@@ -205,7 +205,7 @@ tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5
-tensorflow==2.8.0
+tensorflow==2.8.1
tensorflow-io-gcs-filesystem==0.24.0
termcolor==1.1.0
text-unidecode==1.3
diff --git a/examples/research_projects/deebert/run_glue_deebert.py b/examples/research_projects/deebert/run_glue_deebert.py
index 5bfc2f8816dcad..f86390375ff754 100644
--- a/examples/research_projects/deebert/run_glue_deebert.py
+++ b/examples/research_projects/deebert/run_glue_deebert.py
@@ -459,8 +459,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -529,8 +531,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
diff --git a/examples/research_projects/distillation/grouped_batch_sampler.py b/examples/research_projects/distillation/grouped_batch_sampler.py
index 6c2d9b974886c8..83addc371f2e21 100644
--- a/examples/research_projects/distillation/grouped_batch_sampler.py
+++ b/examples/research_projects/distillation/grouped_batch_sampler.py
@@ -60,7 +60,7 @@ class GroupedBatchSampler(BatchSampler):
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
- "sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
+ "sampler should be an instance of torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler
self.group_ids = group_ids
diff --git a/examples/research_projects/distillation/run_squad_w_distillation.py b/examples/research_projects/distillation/run_squad_w_distillation.py
index ea1f2f46a9697d..3acfd468640626 100644
--- a/examples/research_projects/distillation/run_squad_w_distillation.py
+++ b/examples/research_projects/distillation/run_squad_w_distillation.py
@@ -518,7 +518,10 @@ def main():
"--teacher_type",
default=None,
type=str,
- help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
+ help=(
+ "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
+ " distillation."
+ ),
)
parser.add_argument(
"--teacher_name_or_path",
@@ -590,8 +593,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument(
"--doc_stride",
@@ -603,8 +608,10 @@ def main():
"--max_query_length",
default=64,
type=int,
- help="The maximum number of tokens for the question. Questions longer than this will "
- "be truncated to this length.",
+ help=(
+ "The maximum number of tokens for the question. Questions longer than this will "
+ "be truncated to this length."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -649,14 +656,18 @@ def main():
"--max_answer_length",
default=30,
type=int,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--verbose_logging",
action="store_true",
- help="If true, all of the warnings related to data processing will be printed. "
- "A number of warnings are expected for a normal SQuAD evaluation.",
+ help=(
+ "If true, all of the warnings related to data processing will be printed. "
+ "A number of warnings are expected for a normal SQuAD evaluation."
+ ),
)
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
@@ -685,8 +696,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
diff --git a/examples/research_projects/distillation/scripts/extract.py b/examples/research_projects/distillation/scripts/extract.py
index d7a99b1d89d0da..f60f243dece6c6 100644
--- a/examples/research_projects/distillation/scripts/extract.py
+++ b/examples/research_projects/distillation/scripts/extract.py
@@ -25,7 +25,10 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
+ description=(
+ "Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
+ " Distillation"
+ )
)
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
parser.add_argument("--model_name", default="roberta-large", type=str)
diff --git a/examples/research_projects/distillation/scripts/extract_distilbert.py b/examples/research_projects/distillation/scripts/extract_distilbert.py
index e125f36187cd8a..a58105f999e827 100644
--- a/examples/research_projects/distillation/scripts/extract_distilbert.py
+++ b/examples/research_projects/distillation/scripts/extract_distilbert.py
@@ -25,7 +25,10 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
+ description=(
+ "Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned"
+ " Distillation"
+ )
)
parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
diff --git a/examples/research_projects/distillation/train.py b/examples/research_projects/distillation/train.py
index 6385c885a96e14..cc2362888e4725 100644
--- a/examples/research_projects/distillation/train.py
+++ b/examples/research_projects/distillation/train.py
@@ -207,8 +207,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
@@ -226,8 +228,8 @@ def main():
if os.path.exists(args.dump_path):
if not args.force:
raise ValueError(
- f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
- "Use `--force` if you want to overwrite it"
+ f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite"
+ " itUse `--force` if you want to overwrite it"
)
else:
shutil.rmtree(args.dump_path)
diff --git a/examples/research_projects/fsner/src/fsner/tokenizer_utils.py b/examples/research_projects/fsner/src/fsner/tokenizer_utils.py
index 6e4027a9891d7a..bc5f6650ccd9f5 100644
--- a/examples/research_projects/fsner/src/fsner/tokenizer_utils.py
+++ b/examples/research_projects/fsner/src/fsner/tokenizer_utils.py
@@ -48,7 +48,8 @@ def tokenize(self, x):
else:
raise Exception(
- "Type of parameter x was not recognized! Only `list of strings` for query or `list of lists of strings` for supports are supported."
+ "Type of parameter x was not recognized! Only `list of strings` for query or `list of lists of"
+ " strings` for supports are supported."
)
return d
diff --git a/examples/research_projects/information-gain-filtration/README.md b/examples/research_projects/information-gain-filtration/README.md
new file mode 100644
index 00000000000000..bf95cb8ea81423
--- /dev/null
+++ b/examples/research_projects/information-gain-filtration/README.md
@@ -0,0 +1,100 @@
+
+# Information Gain Filtration(IGF)
+
+Authors @Tuko @mraunak
+
+This folder contains the code how to implement IGF for finetuning on GPT-2.
+
+## What is IGF?
+
+Here we present a general fine-tuning method that we call information gain filtration for improving the overall training efficiency and final
+performance of language model fine-tuning(see paper below). The method is an alternative fine-tuning method that trains
+a secondary model (e.g., a simple convolutional network) to predict the amount of information
+gained over a given pre-trained model. The secondary model is lightweight and trained to
+predict the Information Gain measure. Information Gain is defined as the change in a loss
+function for a model before and after an SGD update with a sample (Equation X in the paper).
+A small subset of the training set named the āobjectiveā set, is used to measure information
+gain on the pre-trained model, and consequently to train the secondary model. After
+training, the model is used for filtering samples for the fine-tuning process. Therefore,
+a high information gain value would suggest a sample is informative, whereas a low value
+would suggest a non-informative sample that should be filtered out. Thus, a thresholding
+strategy is defined to select informative samples. With such a strategy, samples are filtered
+and once enough samples are selected to form a mini-batch and a usual fine-tuning/optimization
+step is applied. The filtration process is repeated until the fine-tuning process is over.
+
+Paper [Selecting Informative Contexts Improves Language Model Finetuning](https://arxiv.org/abs/2005.00175)
+
+# Results
+
+Several experiments were conducted to show the robustness of the IGF method versus the
+standard fine-tuning process. For example, we achieve a median perplexity of 54.0 on the
+Books dataset compared to 57.3 for standard fine-tuning on GPT-2 Small. The code was
+implemented using the Transformers library and Pytorch. While the method may seem more
+expensive, we saw enough evidence that it may lead to a performance benefit in the final models.
+
+![IGF performance](result_igf.png)
+
+Figure 1: Comparing IGF to Standard Fine-tuning:
+IGF with constant (p < 10ā3 , t-test) and shifting(p < 10ā6 , t-test) thresholding significantly outperform standard fine-tuning. The left-hand figure shows
+test-set perplexity after each fine-tuning batch, averaged over 50 runs (error bars denote Ā± one standard error). The right-hand figure shows the perplexity of each
+method after 60 batches. IGF with shifting thresholding (red) clearly improves over standard batched fine-tuning with Adam
+
+## How to use this project?
+
+To fine-tune a transformer model with IGF on a language modeling task, use the following script:
+
+- `model_name_or_path`: Path to pretrained model or model identifier from huggingface.co/models
+- `data_file`: A jbl file containing tokenized data which can be split as objective dataset,
+ train_dataset and test_dataset
+- `igf_data_file`: A jbl file containing the context and information gain pairs to train secondary learner.
+- `context_len`: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded.
+- `size_objective_set`: Number of articles that are long enough to be used as our objective set"
+- `min_len`: The minimum length of the article to be used as objective set
+- `trim`: Truncate the example if it exceeds context length
+- `eval_freq`: Secondary model evaluation can be triggered at eval_freq
+- `max_steps`: To calculate training epochs
+- `number`: The number of examples split to be used as objective_set/test_data
+- `secondary_learner_batch_size`: The batch size of training data for secondary learner
+- `secondary_learner_max_epochs`: The number of epochs to train secondary learner
+- `recopy_model`: Reset the model to the original pretrained GPT-2 weights after each iteration
+- `eval_interval`: Decay the selectivity of our secondary learner filter from"
+ 1 standard deviation above average to 1 below average after eval_interval(10) batches"
+
+
+```python
+python run_clm_igf.py\
+--model_name_or_path "gpt2" \
+--data_file="data/tokenized_stories_train_wikitext103" \
+--igf_data_file="data/IGF_values" \
+--context_len 32 \
+--size_objective_set 100 \
+--min_len 1026 \
+--trim True \
+--eval_freq 100 \
+--max_steps 1000 \
+--secondary_learner_batch_size 128 \
+--secondary_learner_max_epochs 15 \
+--number 100 \
+--recopy_model \
+--eval_interval 10 \
+```
+
+## Citation
+
+If you find the resource useful, please cite the following paper
+
+```
+@inproceedings{antonello-etal-2021-selecting,
+ title = "Selecting Informative Contexts Improves Language Model Fine-tuning",
+ author = "Antonello, Richard and Beckage, Nicole and Turek, Javier and Huth, Alexander",
+ booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
+ month = aug,
+ year = "2021",
+ address = "Online",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/2021.acl-long.87",
+ doi = "10.18653/v1/2021.acl-long.87",
+ pages = "1072--1085",
+}
+```
diff --git a/tests/albert/__init__.py b/examples/research_projects/information-gain-filtration/igf/__init__.py
similarity index 100%
rename from tests/albert/__init__.py
rename to examples/research_projects/information-gain-filtration/igf/__init__.py
diff --git a/examples/research_projects/information-gain-filtration/igf/igf.py b/examples/research_projects/information-gain-filtration/igf/igf.py
new file mode 100644
index 00000000000000..99bd8c2d06d71c
--- /dev/null
+++ b/examples/research_projects/information-gain-filtration/igf/igf.py
@@ -0,0 +1,419 @@
+# Copyright 2022 - Intel Corp. All rights reserved.
+# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Backage
+
+import copy
+import logging
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+import joblib
+from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup
+
+
+logger = logging.getLogger(__name__)
+
+
+def set_seed(seed):
+ """
+ For reproducible training
+
+ Args:
+ seed: A seed for reproducible training
+
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def compute_perplexity(model, test_data, context_len):
+ """
+ Computes perplexity of the transformer model on data in test_data
+
+ Args:
+ model: Pre-trained GPT2 model
+ test_data: Data on which perplexity calculation is required
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+
+ Returns:
+ Perplexity on input test data
+
+ """
+
+ model.eval()
+ device = next(model.parameters()).device
+ eval_batch_size = 1
+ context = torch.zeros((eval_batch_size, context_len), dtype=torch.long, device=device)
+ eval_dataloader = DataLoader(test_data, shuffle=False, batch_size=eval_batch_size)
+ eval_loss = torch.zeros(1, device=device)
+ nb_eval_examples = 0
+ for batch in eval_dataloader:
+ batch.to(device)
+ # pad
+ context.zero_()
+ for i in range(eval_batch_size):
+ context[i, :] = batch[i]
+ outputs = model(context, labels=context)
+ eval_loss += outputs[0].sum().item()
+ nb_eval_examples += batch.size(0)
+ eval_loss = eval_loss / nb_eval_examples
+ perplexity = torch.exp(eval_loss)
+ model.train()
+ return perplexity
+
+
+def load_gpt2(model_name="gpt2"):
+ """
+ load original gpt2 and save off for quicker loading
+
+ Args:
+ model_name: GPT-2
+
+ Returns:
+ GPT-2 model
+
+ """
+
+ model = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True)
+ torch.save(model.state_dict(), model_name + "local.pt")
+ return model
+
+
+def recopy_gpt2(orig_model, device, max_steps):
+ """
+ Reset the model to the original pretrained GPT-2 weights after each iteration
+
+ Args:
+ orig_model: Original pretrained GPT-2 model imported from Transformers library
+ device: CPU/GPU
+ max_steps: number of training steps
+
+ Returns:
+ Original PreTrained GPT-2 model,
+ lm_optimizer: Adam optimizer with Decoupled weight decay
+ lm_scheduler: linear scheduler with the appropriate schedule
+
+ """
+ model = copy.deepcopy(orig_model)
+ model.to(device)
+
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
+ ]
+ lm_optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
+ lm_scheduler = get_linear_schedule_with_warmup(lm_optimizer, 0, max_steps)
+ torch.cuda.empty_cache()
+ return model, lm_optimizer, lm_scheduler
+
+
+def intermittent_save(contexts, real_perps, past_perps, filename):
+
+ """
+ save the perplexity differences to filename
+
+ Args:
+ contexts: Example on which the perplexity is calculated
+ real_perps: Perplexity after back-propagating on the selected context
+ past_perps: Perplexity of model before training on the context
+ filename: File to store perplexity differences
+
+ Returns:
+ file with perplexity differences
+
+ """
+ # save the perplexity differences to filename
+ avg = np.array(real_perps).mean()
+ std = np.array(real_perps).std()
+ perp_diff = (real_perps - avg) / std
+ data_final = list(zip(contexts, perp_diff, past_perps))
+ joblib.dump(data_final, filename)
+
+
+def collect_objective_set(
+ model,
+ orig_perp,
+ context_len,
+ train_data,
+ objective_set,
+ max_steps,
+ device,
+ filename="dev.jbl",
+ recopy_model=recopy_gpt2,
+):
+
+ """
+ Collect individual IGF values from pre-trained transformer model
+ max_steps samples of training data to train secondary model
+
+ Args:
+ model: Pre-trained GPT2 model
+ orig_perp: Perplexity of original pretrained GPT-2 model
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+ train_data: Data to train model
+ objective_set: Contexts used to create (X,IG(X)) pairs which is the training data for secondary learner
+ max_steps: To calculate training epochs of model
+ device: GPU/CPU
+ filename: To store intermediate perplexity differences
+ recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
+
+ Returns:
+ file stored intermediate perplexity differences in intermediate stages
+
+ """
+
+ # initialize variables to record relevant information
+ contexts = []
+ real_perps = []
+ past_perps = []
+
+ # Initialize the transformer model
+ orig_model = copy.deepcopy(model)
+ orig_model.to(device="cpu")
+ torch.cuda.empty_cache()
+
+ # Compute perplexity of initial transformer model for comparison
+ model.train()
+ model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)
+
+ for step in tqdm(range(max_steps)):
+ context = torch.zeros((1, context_len), dtype=torch.long, device=device)
+ story = random.choice(train_data)
+ start = random.randint(0, len(story[0]) - context_len - 1)
+ context[0, :] = story[0][start : start + context_len]
+ lm_optimizer.zero_grad()
+ outputs = model(context, labels=context)
+ lm_loss = outputs[0]
+ past_perp = compute_perplexity(model, context, context_len)
+ model.train()
+ lm_loss.backward()
+ # Do LM backprop
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
+ lm_optimizer.step()
+ lm_scheduler.step() # Update learning rate schedule
+
+ # Compute perplexity after back-propagating on the selected context
+ real_perp = compute_perplexity(model, objective_set, context_len)
+
+ # Periodically save the stored (X, IG(X)) pairs
+ if step % 1000 == 0 and step > 1:
+ intermittent_save(contexts, real_perps, past_perps, filename)
+
+ # Reset the pretrained model to the original pretrained GPT-2 weights after each iteration
+ model, lm_optimizer, lm_scheduler = recopy_model(orig_model, device, max_steps)
+
+ past_perps.append(past_perp.item())
+ real_perps.append(orig_perp - real_perp.item())
+ contexts.append(np.array(context.cpu()))
+
+ intermittent_save(contexts, real_perps, past_perps, filename)
+
+
+def generate_datasets(
+ context_len, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
+):
+ """
+ Generate objective set and training set
+
+ Args:
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+ file: Tokenized data split into training set and objective set
+ number: size of objective dataset
+ min_len: minimum length of a context in objective set
+ trim: If True truncate the context if it exceeds context length
+
+ Returns:
+ Generated objective set and training data
+
+
+ """
+ # Generate objective set and training set
+ # Designate the first number (100) articles that are long enough to be used
+ # as our objective set, rest (that are long enough) are training data for
+ # secondary learner
+
+ data = joblib.load(file)
+ print("data loaded")
+ objective_set = []
+ if trim:
+ for i, example in enumerate(data):
+ if len(example[0]) > min_len:
+ start = random.randint(0, len(example[0]) - context_len - 1)
+ objective_set.append(example[0, start : start + context_len])
+ if len(objective_set) >= number:
+ break
+ train_data = []
+ for j in range(i + 1, len(data)):
+ if len(data[j][0]) > min_len:
+ train_data.append(data[j])
+ else:
+ objective_set = data[0:number]
+ train_data = data[number:]
+
+ joblib.dump(objective_set, "objective_set.jbl")
+ print("objective set saved")
+ return train_data, objective_set
+
+
+def train_secondary_learner(
+ secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
+):
+
+ """
+ Train the secondary learner (igf_model)
+
+ Args:
+ secondary_learner: secondary learner
+ train_dataset: data to train secondary learner
+ max_epochs: number of epochs to train secondary learner
+ batch_size: batch size of training data of secondary learner
+ eval_freq: secondary model evaluation can be triggered at eval_freq
+ igf_model_path: path to store trained secondary learner
+
+ Returns:
+ Trained secondary learner
+
+ """
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # We will use the first 512 pairs from our dataset as a test set for
+ # our secondary learner and the rest to train
+ test_dataset = train_dataset[:512]
+ train_dataset = train_dataset[512:]
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
+ test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
+
+ # secondary learner model set up
+ loss = nn.MSELoss()
+ test_loss = nn.MSELoss(reduction="sum")
+ secondary_learner.to(device)
+ q_optimizer = torch.optim.Adam(secondary_learner.parameters(), lr=0.00001)
+ secondary_learner.train()
+
+ # TODO in original code this is written as number of actual batches seen
+ # not number of items seen but other places it is number of items instead.
+ # improve consistency! changed this to epochs for clarity
+ best_test_loss = float("inf")
+ # Iterate through batches until we've used max_steps batches
+ for epoch in range(int(max_epochs)):
+ tr_q_loss = 0.0
+ secondary_learner.train()
+ for step, batch in enumerate(train_dataloader):
+ context = batch[0].to(device)
+ real_q = batch[1].to(device)
+ predicted_q = secondary_learner(context)
+ q_optimizer.zero_grad()
+ q_loss = loss(predicted_q, real_q.float())
+ q_loss.backward()
+ q_optimizer.step()
+ tr_q_loss += q_loss.item()
+
+ # model trains fairly quickly so we won't wait for a full epoch
+ # eval is triggered at eval_freq and end of epochs
+ if (step % eval_freq == 0 and step > 0) or ((step + 1) == len(train_dataloader)):
+ tr_loss = tr_q_loss / (step + 1)
+
+ secondary_learner.eval()
+ q_loss2 = 0.0
+ sum_q2 = 0.0
+ predicted = []
+ actual = []
+ # Compute performance of the secondary learner after this batch
+ for step2, batch2 in enumerate(test_dataloader):
+ features2 = batch2[0].to(device)
+ real_q2 = batch2[1].to(device)
+ predicted_q2 = secondary_learner(features2)
+ q_loss2 += test_loss(predicted_q2, real_q2).item()
+ sum_q2 += torch.sum(predicted_q2).item()
+ for ei, i in enumerate(predicted_q2.cpu().detach().numpy()):
+ predicted.append(i.item())
+ for ei, i in enumerate(real_q2.cpu().detach().numpy()):
+ actual.append(i.item())
+
+ q_loss2 /= len(test_dataset)
+ print(
+ "Epoch: ",
+ epoch,
+ "step: ",
+ step,
+ "Avg. q:",
+ sum_q2 / len(test_dataset),
+ "Train Loss: ",
+ tr_loss,
+ "Test Loss: ",
+ q_loss2,
+ )
+ if q_loss2 < best_test_loss:
+ joblib.dump((predicted, actual), "pred_vs_actual.jbl")
+ torch.save(secondary_learner.state_dict(), igf_model_path)
+ best_test_loss = q_loss2
+
+ secondary_learner.train()
+ return secondary_learner
+
+
+class SecondaryLearner(nn.Module):
+ """
+ Our secondary learner
+ """
+
+ def __init__(self, model):
+ """
+ We use a simple convolutional network as our secondary learner
+
+ Args:
+ model: Pre-trained GPT2 model
+ """
+ # embeddings are from the pretrained model
+ super(SecondaryLearner, self).__init__()
+ self.embeddings = model.transformer.wte
+ self.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
+ self.conv = nn.Conv1d(self.embeddings.weight.size(1), 256, 3, padding=1)
+ self.fc = nn.Sequential(nn.Linear(256, 32), nn.Dropout(p=0.1), nn.Linear(32, 32), nn.Linear(32, 1))
+
+ def forward(self, context):
+ """
+ Forward pass through the secondary learner
+
+ Args:
+ context: Context input to the secondary learner
+
+ Returns:
+ tensor after squeeze operation
+
+ """
+ pooled = torch.max(self.conv(self.embeddings(context).squeeze(1).transpose(1, 2)), 2)[0]
+ qs = self.fc(pooled)
+ return qs.squeeze(1)
+
+ @classmethod
+ def from_pretrained(cls, state_path, model):
+ """
+ Load the secondary learner
+
+ Args:
+ state_path: Path to save secondary learner
+ model: Pretrained GPT-2
+
+ Returns:
+ secondary learner
+ """
+
+ secondary_learner = cls(model) # this calls __init__
+ state_dict = torch.load(state_path)
+ secondary_learner.load_state_dict(state_dict)
+ secondary_learner.embeddings = model.transformer.wte
+ secondary_learner.embeddings.weight = copy.deepcopy(model.transformer.wte.weight)
+ return secondary_learner
diff --git a/examples/research_projects/information-gain-filtration/requirements.txt b/examples/research_projects/information-gain-filtration/requirements.txt
new file mode 100644
index 00000000000000..2aa3227637c888
--- /dev/null
+++ b/examples/research_projects/information-gain-filtration/requirements.txt
@@ -0,0 +1,6 @@
+matplotlib
+numpy>=1.17.2
+joblib>=0.13.2
+scipy
+torch>=1.10.1
+transformers>=3.5
\ No newline at end of file
diff --git a/examples/research_projects/information-gain-filtration/result_igf.png b/examples/research_projects/information-gain-filtration/result_igf.png
new file mode 100644
index 00000000000000..10bb0b7d681630
Binary files /dev/null and b/examples/research_projects/information-gain-filtration/result_igf.png differ
diff --git a/examples/research_projects/information-gain-filtration/run_clm_igf.py b/examples/research_projects/information-gain-filtration/run_clm_igf.py
new file mode 100644
index 00000000000000..eae10060b22fd1
--- /dev/null
+++ b/examples/research_projects/information-gain-filtration/run_clm_igf.py
@@ -0,0 +1,446 @@
+# Copyright 2022 - Intel Corp. All rights reserved.
+# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage
+
+"""
+Implementation of a new method for fine-tuning transformer models that we call
+Information Gain Filtration 'IGF' on WikiText data set and compared the results
+with the standard fine-tuning method
+
+Steps followed in the code:
+
+1) Generate a objective dataset of pairs (X, IG(X)). IG(X)--Informativeness of context 'X'.
+Our IG (information gain) model is learning to predict the āinformativenessā of a particular
+context. Informativeness is the change in metric between the modelās accuracy on an
+objective set before and after seeing that context. For casual language modeling, the
+metric is perplexity.
+
+2) A secondary learner is trained to infer a function approximation for IG using the dataset
+created in (1).
+
+3) The learner created in (2) is used to inform the fine-tuning process and filter out low informative samples.
+
+Last, a plot is generated to compare the performance of IGF to standard fine-tuning without any filtering
+
+"""
+
+# Prerequisite libraries:
+
+import argparse
+import random
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, RandomSampler
+
+import joblib
+from igf.igf import (
+ SecondaryLearner,
+ collect_objective_set,
+ compute_perplexity,
+ generate_datasets,
+ load_gpt2,
+ recopy_gpt2,
+ set_seed,
+ train_secondary_learner,
+)
+from transformers import GPT2LMHeadModel
+
+
+def generate_n_pairs(
+ context_len=32,
+ max_steps=10,
+ size_objective_set=100,
+ min_len=1026,
+ trim=True,
+ data_file="data/tokenized_stories_train_wikitext103.jbl",
+ igf_data_file="igf_context_pairs.jbl",
+):
+
+ """
+ Collecting *n* pairs for training the secondary learner
+ Args:
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+ max_steps: To calculate training epochs of secondary learner
+ size_objective_set: size of objective data set used to create (X,IG(X)) pairs which is the training data for secondary learner
+ min_len: The minimum length of the article to be used as objective set
+ trim: If True truncate the context if it exceeds context length
+ data_file: Tokenized data set split for training and evaluation of model
+ igf_data_file: file to store (I,IG(X)) paired data set to train secondary learner
+
+ Returns:
+ Data stored in igf_data_file
+
+ """
+ # generates same data everytime
+ set_seed(3)
+ # generate train_data and objective_set
+ train_data, objective_set = generate_datasets(
+ context_len, data_file, number=size_objective_set, min_len=1026, trim=True
+ )
+ # keeps model same across runs
+ set_seed(4)
+ # model, lm_optimizer, lm_scheduler = recopy_gpt2(model, device, max_steps) # store original model weights
+ # can we train on GPU?
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # load pretrained model
+ model = load_gpt2("gpt2").to(device)
+ print("computing perplexity on objective set")
+ orig_perp = compute_perplexity(model, objective_set, context_len).item()
+ print("perplexity on objective set:", orig_perp)
+
+ # collect igf pairs and save to file demo.jbl
+ collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)
+
+ # clean up, delete model and data we don't need anymore
+ del model, train_data, objective_set
+ torch.cuda.empty_cache()
+
+
+def training_secondary_learner(
+ secondary_learner_train_data,
+ secondary_learner_max_epochs=15,
+ secondary_learner_batch_size=128,
+ eval_freq=100,
+ igf_model_path="igf_model.pt",
+):
+ """
+ Train the secondary learner
+
+ Args:
+ secondary_learner_train_data: Data set with (X,IG(X)) pairs to train secondary learner where IG(X) - measure of informativeness and X- context
+ secondary_learner_max_epochs: Number of epochs to train secondary learner
+ secondary_learner_batch_size: Batch size to train secondary learner
+ eval_freq (object): secondary model evaluation can be triggered at eval_freq
+ igf_model_path: path to store trained secondary learner
+
+ Returns:
+ Trained secondary learner
+ """
+
+ set_seed(42)
+
+ # Load pre-trained model
+ model = GPT2LMHeadModel.from_pretrained("gpt2")
+
+ # Initialize secondary learner to use embedding weights of model
+ secondary_learner = SecondaryLearner(model)
+
+ # Train secondary learner
+ secondary_learner = train_secondary_learner(
+ secondary_learner,
+ secondary_learner_train_data,
+ max_epochs=secondary_learner_max_epochs,
+ batch_size=secondary_learner_batch_size,
+ eval_freq=100,
+ igf_model_path=igf_model_path,
+ )
+
+ del model, secondary_learner_train_data
+ torch.cuda.empty_cache()
+
+ return secondary_learner
+
+
+def finetune(
+ model,
+ train_dataset,
+ test_dataset,
+ context_len=32,
+ max_steps=1000,
+ batch_size=16,
+ threshold=1.0,
+ recopy_model=recopy_gpt2,
+ secondary_learner=None,
+ eval_interval=10,
+ finetuned_model_name="gpt2_finetuned.pt",
+):
+ """
+ fine-tune with IGF if secondary_learner is not None, else standard fine-tuning
+
+ Args:
+ model: pre-trained GPT-2 model
+ train_dataset: Data set to train GPT-2 model
+ test_dataset: Evaluate GPT-2 model
+ context_len: The maximum total input sequence length after tokenization. Sequences longer
+ than this will be truncated, sequences shorter will be padded
+ max_steps: To calculate training epochs
+ batch_size: Batch size to train GPT-2 model
+ threshold: The threshold value used by secondary learner to filter the train_data and allow only"
+ informative data as input to the model
+ recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
+ secondary_learner: Selection of IGF as fine-tuning method if not None
+ eval_interval: number of batches after which decay the selectivity of our secondary learner filter from
+ 1 standard deviation above average to 1 below average
+ fine-tuned_model_name: name of the final final-tuned GPT-2 model
+
+ Returns:
+ Fine-tuned GPT-2 model
+
+ """
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ train_sampler = RandomSampler(train_dataset)
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler)
+
+ num_train_epochs = max_steps // (len(train_dataset)) + 1
+ global_step = 0
+ context = torch.zeros((1, context_len), dtype=torch.long, device=device)
+ model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)
+
+ model.train()
+ if secondary_learner is not None:
+ secondary_learner.to(device)
+ secondary_learner.eval()
+ contexts = []
+ examples = 0
+
+ observed_qs = []
+ test_perps = []
+
+ # Compute the performance of the transformer model at the beginning
+ real_perp = compute_perplexity(model, test_dataset, context_len)
+ test_perps.append(real_perp)
+ print("Test perplexity, step", global_step, ":", real_perp)
+ for epoch in range(int(num_train_epochs)):
+ for step, example in enumerate(train_dataloader):
+ torch.cuda.empty_cache()
+ start = random.randint(0, example.size(2) - context_len - 1)
+ context[0, :] = example[0, 0, start : start + context_len]
+ lm_optimizer.zero_grad()
+ outputs = model(context, labels=context)
+ do_backprop = True
+
+ if secondary_learner is not None:
+ predicted_q = secondary_learner.forward(
+ torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
+ )[0].item()
+ observed_qs.append(float(predicted_q))
+
+ # Here we implement the simple non-constant threshold for the predicted IG(X) value
+ # We will decay the selectivity of our secondary learner filter from
+ # 1 standard deviation above average to 1 below average after 10 batches.
+
+ if global_step == 10:
+ threshold = -1
+ if predicted_q < threshold:
+ do_backprop = False
+
+ # If we passed the filter, add the context to the batch!
+ if do_backprop:
+ contexts.append(np.array(context.cpu()))
+ lm_loss = outputs[0]
+ lm_loss.backward()
+ examples += 1
+
+ del outputs
+
+ # Once the batch is filled with enough contexts, backprop on the batch.
+ if examples == batch_size:
+ torch.cuda.empty_cache()
+ examples = 0
+ # Do LM backprop
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
+ lm_optimizer.step()
+ lm_scheduler.step() # Update learning rate schedule
+ global_step += 1
+ # Compute the performance of the transformer model at this batch
+ if global_step % eval_interval == 0:
+ real_perp = compute_perplexity(model, test_dataset, context_len)
+ test_perps.append(real_perp)
+
+ print("Test perplexity, step", global_step, ":", real_perp)
+ # Break out of the loop after 60 batches
+ if max_steps > 0 and global_step > 60:
+ break
+ if max_steps > 0 and global_step > 60:
+ break
+
+ # save finetuned transformer model
+ torch.save(model.state_dict(), finetuned_model_name)
+ torch.cuda.empty_cache()
+ # Do some cleaning up so we can reinitialize for the next run of this function
+ del lm_optimizer
+ del lm_scheduler
+ return model
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")
+
+ # Required parameters
+ parser.add_argument(
+ "--data_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The input data dir. Should contain data files for WikiText.",
+ )
+ parser.add_argument(
+ "--model_name_or_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models",
+ )
+ parser.add_argument(
+ "--data_file",
+ type=str,
+ default=None,
+ help=(
+ "A jbl file containing tokenized data which can be split as objective dataset, "
+ "train_dataset and test_dataset."
+ ),
+ )
+
+ parser.add_argument(
+ "--igf_data_file",
+ type=str,
+ default=None,
+ help="A jbl file containing the context and information gain pairs to train secondary learner.",
+ )
+
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The output directory where the final fine-tuned model is stored.",
+ )
+
+ parser.add_argument(
+ "--tokenizer_name",
+ default=None,
+ type=str,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+
+ parser.add_argument(
+ "--context_len",
+ default=32,
+ type=int,
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
+ )
+
+ parser.add_argument(
+ "--size_objective_set",
+ default=100,
+ type=int,
+ help="number of articles that are long enough to be used as our objective set",
+ )
+ parser.add_argument(
+ "--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
+ )
+
+ parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")
+
+ parser.add_argument(
+ "--secondary_learner_batch_size",
+ default=128,
+ type=int,
+ help="batch size of training data for secondary learner",
+ )
+
+ parser.add_argument(
+ "--batch_size", default=16, type=int, help="batch size of training data of language model(gpt2) "
+ )
+
+ parser.add_argument(
+ "--eval_interval",
+ default=10,
+ type=int,
+ help=(
+ "decay the selectivity of our secondary learner filter from"
+ "1 standard deviation above average to 1 below average after 10 batches"
+ ),
+ )
+
+ parser.add_argument(
+ "--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
+ )
+
+ parser.add_argument(
+ "--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
+ )
+
+ parser.add_argument(
+ "--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
+ )
+
+ parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")
+
+ parser.add_argument(
+ "--threshold",
+ default=1.0,
+ type=float,
+ help=(
+ "The threshold value used by secondary learner to filter the train_data and allow only"
+ " informative data as input to the model"
+ ),
+ )
+
+ parser.add_argument("--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="finetuned_model_name")
+
+ parser.add_argument(
+ "--recopy_model",
+ default=recopy_gpt2,
+ type=str,
+ help="Reset the model to the original pretrained GPT-2 weights after each iteration",
+ )
+
+ # function calls
+ # Collecting *n* pairs of context and information gain(X, IG(X)) for training the secondary learner
+ generate_n_pairs(
+ context_len=32,
+ max_steps=10,
+ size_objective_set=100,
+ min_len=1026,
+ trim=True,
+ data_file="data/tokenized_stories_train_wikitext103.jbl",
+ igf_data_file="igf_context_pairs.jbl",
+ )
+
+ # Load train data for secondary learner
+ secondary_learner_train_data = joblib.load("data/IGF_values.jbl")
+
+ # Train secondary learner
+ secondary_learner = training_secondary_learner(
+ secondary_learner_train_data,
+ secondary_learner_max_epochs=15,
+ secondary_learner_batch_size=128,
+ eval_freq=100,
+ igf_model_path="igf_model.pt",
+ )
+
+ # load pretrained gpt2 model
+ model = GPT2LMHeadModel.from_pretrained("gpt2")
+ set_seed(42)
+
+ # Generate train and test data to train and evaluate gpt2 model
+ train_dataset, test_dataset = generate_datasets(
+ context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
+ )
+
+ # fine-tuning of the gpt2 model using igf (Information Gain Filtration)
+ finetune(
+ model,
+ train_dataset,
+ test_dataset,
+ context_len=32,
+ max_steps=1000,
+ batch_size=16,
+ threshold=1.0,
+ recopy_model=recopy_gpt2,
+ secondary_learner=secondary_learner,
+ eval_interval=10,
+ finetuned_model_name="gpt2_finetuned.pt",
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
index 0bb4a7b9c5142b..f0f3e873d83f09 100755
--- a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
+++ b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
@@ -75,8 +75,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -99,7 +100,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -141,8 +145,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -155,8 +161,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
line_by_line: bool = field(
@@ -280,8 +288,10 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
tokenized_samples = next(train_iterator)
i += len(tokenized_samples["input_ids"])
- # concatenate tokenized samples to list
- samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
+ # concatenate tokenized samples to list (excluding "id" and "text")
+ samples = {
+ k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
+ }
# Concatenated tokens are split to lists of length `max_seq_length`.
# Note that remainedr of % max_seq_length are thrown away.
@@ -399,10 +409,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
def tokenize_function(examples):
return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
- tokenized_datasets = dataset.map(
- tokenize_function,
- batched=True,
- )
+ tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys()))
shuffle_seed = training_args.seed
tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
@@ -575,7 +582,8 @@ def eval_step(params, batch):
if step % training_args.logging_steps == 0 and step > 0:
steps.write(
- f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
+ f" {train_metric['learning_rate'].mean()})"
)
train_time += time.time() - train_start
if has_tensorboard and jax.process_index() == 0:
@@ -604,7 +612,10 @@ def eval_step(params, batch):
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
- steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+ steps.desc = (
+ f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc:"
+ f" {eval_metrics['accuracy']})"
+ )
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, step)
diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
index 0572a4e019a87f..6ee974666a291a 100644
--- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
+++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
@@ -77,14 +77,18 @@ class ModelArguments:
text_model_name_or_path: str = field(
metadata={
- "help": "The text model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The text model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ )
},
)
vision_model_name_or_path: str = field(
metadata={
- "help": "The vision model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The vision model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ )
},
)
from_pt: bool = field(
@@ -107,7 +111,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -129,22 +136,28 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=72,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -519,7 +532,8 @@ def eval_step(params, batch):
train_step_progress_bar.close()
epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
index 3371dc3bd4df24..518ef9f7b22f3e 100644
--- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
+++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
@@ -69,8 +69,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -93,7 +94,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -118,15 +122,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
overwrite_cache: bool = field(
@@ -141,9 +149,11 @@ class DataTrainingArguments:
block_size: Optional[int] = field(
default=None,
metadata={
- "help": "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
overwrite_cache: bool = field(
@@ -334,7 +344,8 @@ def tokenize_function(examples):
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
- "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+ " before being passed to the model."
)
return output
@@ -606,7 +617,8 @@ def eval_step(input_ids, labels, params):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
+ f" {train_metric['learning_rate']})"
)
train_metrics = []
@@ -632,7 +644,8 @@ def eval_step(input_ids, labels, params):
eval_metrics["perplexity"] = float("inf")
logger.info(
- f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
+ f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity:"
+ f" {eval_metrics['perplexity']}"
)
if cur_step % training_args.save_steps == 0 and cur_step > 0:
diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py
index e2bcd7861beca0..b0600d978bd946 100755
--- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py
+++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py
@@ -64,7 +64,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+ "help": (
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
+ " `[float32, float16, bfloat16]`."
+ )
},
)
@@ -94,7 +97,9 @@ class DataTrainingArguments:
validation_split_name: Optional[str] = field(
default="validation",
metadata={
- "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ "help": (
+ "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
speech_file_column: Optional[str] = field(
@@ -120,7 +125,10 @@ class DataTrainingArguments:
pad_to_multiple_of: Optional[int] = field(
default=1024,
metadata={
- "help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
+ "help": (
+ "If set will pad the sequence to a multiple of the provided value. This is important to avoid"
+ " triggering recompilations on TPU"
+ )
},
)
@@ -357,7 +365,8 @@ def normalize(batch):
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
- "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
+ "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
+ " ``config.feat_extract_norm='layer'"
)
model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
@@ -557,7 +566,8 @@ def eval_step(params, batch):
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
- f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
+ f" {train_metric['learning_rate'].mean()})"
)
train_metrics = []
@@ -583,7 +593,8 @@ def eval_step(params, batch):
# Update progress bar
epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity:"
+ f" {eval_metrics['codevector_perplexity']})"
)
# Save metrics
diff --git a/examples/research_projects/layoutlmv3/README.md b/examples/research_projects/layoutlmv3/README.md
new file mode 100644
index 00000000000000..17bf4bb67cd90f
--- /dev/null
+++ b/examples/research_projects/layoutlmv3/README.md
@@ -0,0 +1,69 @@
+
+
+# Token classification with LayoutLMv3 (PyTorch version)
+
+This directory contains a script, `run_funsd_cord.py`, that can be used to fine-tune (or evaluate) LayoutLMv3 on form understanding datasets, such as [FUNSD](https://guillaumejaume.github.io/FUNSD/) and [CORD](https://github.com/clovaai/cord).
+
+The script `run_funsd_cord.py` leverages the š¤ Datasets library and the Trainer API. You can easily customize it to your needs.
+
+## Fine-tuning on FUNSD
+
+Fine-tuning LayoutLMv3 for token classification on [FUNSD](https://guillaumejaume.github.io/FUNSD/) can be done as follows:
+
+```bash
+python run_funsd_cord.py \
+ --model_name_or_path microsoft/layoutlmv3-base \
+ --dataset_name funsd \
+ --output_dir layoutlmv3-test \
+ --do_train \
+ --do_eval \
+ --max_steps 1000 \
+ --evaluation_strategy steps \
+ --eval_steps 100 \
+ --learning_rate 1e-5 \
+ --load_best_model_at_end \
+ --metric_for_best_model "eval_f1" \
+ --push_to_hub \
+ --push_to_hubĀ°model_id layoutlmv3-finetuned-funsd
+```
+
+š The resulting model can be found here: https://huggingface.co/nielsr/layoutlmv3-finetuned-funsd. By specifying the `push_to_hub` flag, the model gets uploaded automatically to the hub (regularly), together with a model card, which includes metrics such as precision, recall and F1. Note that you can easily update the model card, as it's just a README file of the respective repo on the hub.
+
+There's also the "Training metrics" [tab](https://huggingface.co/nielsr/layoutlmv3-finetuned-funsd/tensorboard), which shows Tensorboard logs over the course of training. Pretty neat, huh?
+
+## Fine-tuning on CORD
+
+Fine-tuning LayoutLMv3 for token classification on [CORD](https://github.com/clovaai/cord) can be done as follows:
+
+```bash
+python run_funsd_cord.py \
+ --model_name_or_path microsoft/layoutlmv3-base \
+ --dataset_name cord \
+ --output_dir layoutlmv3-test \
+ --do_train \
+ --do_eval \
+ --max_steps 1000 \
+ --evaluation_strategy steps \
+ --eval_steps 100 \
+ --learning_rate 5e-5 \
+ --load_best_model_at_end \
+ --metric_for_best_model "eval_f1" \
+ --push_to_hub \
+ --push_to_hubĀ°model_id layoutlmv3-finetuned-cord
+```
+
+š The resulting model can be found here: https://huggingface.co/nielsr/layoutlmv3-finetuned-cord. Note that a model card gets generated automatically in case you specify the `push_to_hub` flag.
\ No newline at end of file
diff --git a/examples/research_projects/layoutlmv3/requirements.txt b/examples/research_projects/layoutlmv3/requirements.txt
new file mode 100644
index 00000000000000..504a8cc9870fa0
--- /dev/null
+++ b/examples/research_projects/layoutlmv3/requirements.txt
@@ -0,0 +1,2 @@
+datasets
+seqeval
\ No newline at end of file
diff --git a/examples/research_projects/layoutlmv3/run_funsd_cord.py b/examples/research_projects/layoutlmv3/run_funsd_cord.py
new file mode 100644
index 00000000000000..66be61dffccf20
--- /dev/null
+++ b/examples/research_projects/layoutlmv3/run_funsd_cord.py
@@ -0,0 +1,533 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning LayoutLMv3 for token classification on FUNSD or CORD.
+"""
+# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as
+# comments.
+
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from typing import Optional
+
+import datasets
+import numpy as np
+from datasets import ClassLabel, load_dataset, load_metric
+
+import transformers
+from transformers import (
+ AutoConfig,
+ AutoModelForTokenClassification,
+ AutoProcessor,
+ HfArgumentParser,
+ Trainer,
+ TrainingArguments,
+ set_seed,
+)
+from transformers.data.data_collator import default_data_collator
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version
+from transformers.utils.versions import require_version
+
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.19.0.dev0")
+
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelArguments:
+ """
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+ """
+
+ model_name_or_path: str = field(
+ default="microsoft/layoutlmv3-base",
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
+ )
+ config_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+ )
+ processor_name: Optional[str] = field(
+ default=None, metadata={"help": "Name or path to the processor files if not the same as model_name"}
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+ )
+ use_auth_token: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
+ },
+ )
+
+
+@dataclass
+class DataTrainingArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+ """
+
+ task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
+ dataset_name: Optional[str] = field(
+ default="nielsr/funsd-layoutlmv3",
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
+ )
+ dataset_config_name: Optional[str] = field(
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+ )
+ train_file: Optional[str] = field(
+ default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
+ )
+ validation_file: Optional[str] = field(
+ default=None,
+ metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
+ )
+ test_file: Optional[str] = field(
+ default=None,
+ metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
+ )
+ text_column_name: Optional[str] = field(
+ default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
+ )
+ label_column_name: Optional[str] = field(
+ default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
+ )
+ overwrite_cache: bool = field(
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+ )
+ preprocessing_num_workers: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of processes to use for the preprocessing."},
+ )
+ max_seq_length: int = field(
+ default=512,
+ metadata={
+ "help": (
+ "The maximum total input sequence length after tokenization. If set, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
+ },
+ )
+ max_train_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
+ },
+ )
+ max_eval_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
+ },
+ )
+ max_predict_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
+ },
+ )
+ label_all_tokens: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to put the label for one word on all tokens of generated by that word or just on the "
+ "one (in which case the other tokens will have a padding index)."
+ )
+ },
+ )
+ return_entity_level_metrics: bool = field(
+ default=False,
+ metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
+ )
+
+ def __post_init__(self):
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+ raise ValueError("Need either a dataset name or a training/validation file.")
+ else:
+ if self.train_file is not None:
+ extension = self.train_file.split(".")[-1]
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+ if self.validation_file is not None:
+ extension = self.validation_file.split(".")[-1]
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+ self.task_name = self.task_name.lower()
+
+
+def main():
+ # See all possible arguments in src/transformers/training_args.py
+ # or by passing the --help flag to this script.
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ # If we pass only one argument to the script and it's the path to a json file,
+ # let's parse it to get our arguments.
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+ else:
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+ datasets.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ # Log on each process the small summary:
+ logger.warning(
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+ )
+ logger.info(f"Training/evaluation parameters {training_args}")
+
+ # Detecting last checkpoint.
+ last_checkpoint = None
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+ raise ValueError(
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+ "Use --overwrite_output_dir to overcome."
+ )
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+ logger.info(
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+ )
+
+ # Set seed before initializing model.
+ set_seed(training_args.seed)
+
+ # Get the datasets
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
+ # download the dataset.
+ if data_args.dataset_name == "funsd":
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ "nielsr/funsd-layoutlmv3",
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ elif data_args.dataset_name == "cord":
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ "nielsr/cord-layoutlmv3",
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ else:
+ raise ValueError("This script only supports either FUNSD or CORD out-of-the-box.")
+
+ if training_args.do_train:
+ column_names = dataset["train"].column_names
+ features = dataset["train"].features
+ else:
+ column_names = dataset["test"].column_names
+ features = dataset["test"].features
+
+ image_column_name = "image"
+ text_column_name = "words" if "words" in column_names else "tokens"
+ boxes_column_name = "bboxes"
+ label_column_name = (
+ f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
+ )
+
+ remove_columns = column_names
+
+ # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
+ # unique labels.
+ def get_label_list(labels):
+ unique_labels = set()
+ for label in labels:
+ unique_labels = unique_labels | set(label)
+ label_list = list(unique_labels)
+ label_list.sort()
+ return label_list
+
+ # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
+ # Otherwise, we have to get the list of labels manually.
+ if isinstance(features[label_column_name].feature, ClassLabel):
+ label_list = features[label_column_name].feature.names
+ # No need to convert the labels since they are already ints.
+ id2label = {k: v for k, v in enumerate(label_list)}
+ label2id = {v: k for k, v in enumerate(label_list)}
+ else:
+ label_list = get_label_list(datasets["train"][label_column_name])
+ id2label = {k: v for k, v in enumerate(label_list)}
+ label2id = {v: k for k, v in enumerate(label_list)}
+ num_labels = len(label_list)
+
+ # Load pretrained model and processor
+ #
+ # Distributed training:
+ # The .from_pretrained methods guarantee that only one local process can concurrently
+ # download model & vocab.
+ config = AutoConfig.from_pretrained(
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+ num_labels=num_labels,
+ finetuning_task=data_args.task_name,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+
+ processor = AutoProcessor.from_pretrained(
+ model_args.processor_name if model_args.processor_name else model_args.model_name_or_path,
+ cache_dir=model_args.cache_dir,
+ use_fast=True,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ add_prefix_space=True,
+ apply_ocr=False,
+ )
+
+ model = AutoModelForTokenClassification.from_pretrained(
+ model_args.model_name_or_path,
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
+ config=config,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+
+ # Set the correspondences label/ID inside the model config
+ model.config.label2id = label2id
+ model.config.id2label = id2label
+
+ # Preprocessing the dataset
+ # The processor does everything for us (prepare the image using LayoutLMv3FeatureExtractor
+ # and prepare the words, boxes and word-level labels using LayoutLMv3TokenizerFast)
+ def prepare_examples(examples):
+ images = examples[image_column_name]
+ words = examples[text_column_name]
+ boxes = examples[boxes_column_name]
+ word_labels = examples[label_column_name]
+
+ encoding = processor(
+ images,
+ words,
+ boxes=boxes,
+ word_labels=word_labels,
+ truncation=True,
+ padding="max_length",
+ max_length=data_args.max_seq_length,
+ )
+
+ return encoding
+
+ if training_args.do_train:
+ if "train" not in dataset:
+ raise ValueError("--do_train requires a train dataset")
+ train_dataset = dataset["train"]
+ if data_args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ with training_args.main_process_first(desc="train dataset map pre-processing"):
+ train_dataset = train_dataset.map(
+ prepare_examples,
+ batched=True,
+ remove_columns=remove_columns,
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ if training_args.do_eval:
+ validation_name = "test"
+ if validation_name not in dataset:
+ raise ValueError("--do_eval requires a validation dataset")
+ eval_dataset = dataset[validation_name]
+ if data_args.max_eval_samples is not None:
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ with training_args.main_process_first(desc="validation dataset map pre-processing"):
+ eval_dataset = eval_dataset.map(
+ prepare_examples,
+ batched=True,
+ remove_columns=remove_columns,
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ if training_args.do_predict:
+ if "test" not in datasets:
+ raise ValueError("--do_predict requires a test dataset")
+ predict_dataset = datasets["test"]
+ if data_args.max_predict_samples is not None:
+ max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+ predict_dataset = predict_dataset.select(range(max_predict_samples))
+ with training_args.main_process_first(desc="prediction dataset map pre-processing"):
+ predict_dataset = predict_dataset.map(
+ prepare_examples,
+ batched=True,
+ remove_columns=remove_columns,
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=not data_args.overwrite_cache,
+ )
+
+ # Metrics
+ metric = load_metric("seqeval")
+
+ def compute_metrics(p):
+ predictions, labels = p
+ predictions = np.argmax(predictions, axis=2)
+
+ # Remove ignored index (special tokens)
+ true_predictions = [
+ [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+ true_labels = [
+ [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+
+ results = metric.compute(predictions=true_predictions, references=true_labels)
+ if data_args.return_entity_level_metrics:
+ # Unpack nested dictionaries
+ final_results = {}
+ for key, value in results.items():
+ if isinstance(value, dict):
+ for n, v in value.items():
+ final_results[f"{key}_{n}"] = v
+ else:
+ final_results[key] = value
+ return final_results
+ else:
+ return {
+ "precision": results["overall_precision"],
+ "recall": results["overall_recall"],
+ "f1": results["overall_f1"],
+ "accuracy": results["overall_accuracy"],
+ }
+
+ # Initialize our Trainer
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset if training_args.do_train else None,
+ eval_dataset=eval_dataset if training_args.do_eval else None,
+ tokenizer=processor,
+ data_collator=default_data_collator,
+ compute_metrics=compute_metrics,
+ )
+
+ # Training
+ if training_args.do_train:
+ checkpoint = None
+ if training_args.resume_from_checkpoint is not None:
+ checkpoint = training_args.resume_from_checkpoint
+ elif last_checkpoint is not None:
+ checkpoint = last_checkpoint
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
+ metrics = train_result.metrics
+ trainer.save_model() # Saves the tokenizer too for easy upload
+
+ max_train_samples = (
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+ )
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+ # Evaluation
+ if training_args.do_eval:
+ logger.info("*** Evaluate ***")
+
+ metrics = trainer.evaluate()
+
+ max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Predict
+ if training_args.do_predict:
+ logger.info("*** Predict ***")
+
+ predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
+ predictions = np.argmax(predictions, axis=2)
+
+ # Remove ignored index (special tokens)
+ true_predictions = [
+ [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+
+ trainer.log_metrics("predict", metrics)
+ trainer.save_metrics("predict", metrics)
+
+ # Save predictions
+ output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt")
+ if trainer.is_world_process_zero():
+ with open(output_predictions_file, "w") as writer:
+ for prediction in true_predictions:
+ writer.write(" ".join(prediction) + "\n")
+
+ kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"}
+ if data_args.dataset_name is not None:
+ kwargs["dataset_tags"] = data_args.dataset_name
+ if data_args.dataset_config_name is not None:
+ kwargs["dataset_args"] = data_args.dataset_config_name
+ kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
+ else:
+ kwargs["dataset"] = data_args.dataset_name
+
+ if training_args.push_to_hub:
+ trainer.push_to_hub(**kwargs)
+ else:
+ trainer.create_model_card(**kwargs)
+
+
+def _mp_fn(index):
+ # For xla_spawn (TPUs)
+ main()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/longform-qa/eli5_utils.py b/examples/research_projects/longform-qa/eli5_utils.py
index ff72a16bfd235b..82c4bd8caf20d3 100644
--- a/examples/research_projects/longform-qa/eli5_utils.py
+++ b/examples/research_projects/longform-qa/eli5_utils.py
@@ -137,7 +137,7 @@ def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_bat
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
- attention_mask, input_shape, device
+ attention_mask, input_shape
)
# define function for checkpointing
@@ -649,7 +649,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
" " + "
".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
]
all_res_lists = []
- for (res_passages, dl) in zip(res_passages_lst, D):
+ for res_passages, dl in zip(res_passages_lst, D):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
for r, sc in zip(res_list, dl):
r["score"] = float(sc)
@@ -679,7 +679,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
"
" + "
".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
]
all_res_lists = []
- for (res_passages, dl, il) in zip(res_passages_lst, D, I):
+ for res_passages, dl, il in zip(res_passages_lst, D, I):
res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
for r, sc, i in zip(res_list, dl, il):
r["passage_id"] = int(i)
diff --git a/examples/research_projects/luke/README.md b/examples/research_projects/luke/README.md
index a4eb1370436b91..703eb0b4e4235c 100644
--- a/examples/research_projects/luke/README.md
+++ b/examples/research_projects/luke/README.md
@@ -14,7 +14,7 @@ the mean of the [š¤ `Accelerate`](https://github.com/huggingface/accelerate) l
after installing it:
```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
```
then to train English LUKE on CoNLL2003:
diff --git a/examples/research_projects/luke/run_luke_ner_no_trainer.py b/examples/research_projects/luke/run_luke_ner_no_trainer.py
index c7a9763d99659d..cb81402425ff2d 100644
--- a/examples/research_projects/luke/run_luke_ner_no_trainer.py
+++ b/examples/research_projects/luke/run_luke_ner_no_trainer.py
@@ -101,8 +101,8 @@ def parse_args():
type=int,
default=32,
help=(
- "The maximum total input entity length after tokenization (Used only for (M)Luke models). Sequences longer than this will be truncated,"
- " sequences shorter will be padded if `--pad_to_max_length` is passed."
+ "The maximum total input entity length after tokenization (Used only for (M)Luke models). Sequences longer"
+ " than this will be truncated, sequences shorter will be padded if `--pad_to_max_length` is passed."
),
)
parser.add_argument(
@@ -110,8 +110,8 @@ def parse_args():
type=int,
default=30,
help=(
- "The maximum total input mention length after tokenization (Used only for (M)Luke models). Sequences longer than this will be truncated,"
- " sequences shorter will be padded if `--pad_to_max_length` is passed."
+ "The maximum total input mention length after tokenization (Used only for (M)Luke models). Sequences"
+ " longer than this will be truncated, sequences shorter will be padded if `--pad_to_max_length` is passed."
),
)
parser.add_argument(
diff --git a/examples/research_projects/lxmert/demo.ipynb b/examples/research_projects/lxmert/demo.ipynb
index 55658ae111e636..e80865d0e2c8f4 100644
--- a/examples/research_projects/lxmert/demo.ipynb
+++ b/examples/research_projects/lxmert/demo.ipynb
@@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
- "#%pip install-r requirements.txt"
+ "# %pip install-r requirements.txt"
]
},
{
diff --git a/examples/research_projects/lxmert/modeling_frcnn.py b/examples/research_projects/lxmert/modeling_frcnn.py
index 39a0c6aea8787d..33c1133e9589f4 100644
--- a/examples/research_projects/lxmert/modeling_frcnn.py
+++ b/examples/research_projects/lxmert/modeling_frcnn.py
@@ -592,7 +592,7 @@ def __call__(self, match_quality_matrix):
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
- for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
+ for l, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
low_high = (matched_vals >= low) & (matched_vals < high)
match_labels[low_high] = l
@@ -1037,9 +1037,9 @@ def make_stage(
curr_kwargs = {}
for k, v in kwargs.items():
if k.endswith("_per_block"):
- assert len(v) == num_blocks, (
- f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
- )
+ assert (
+ len(v) == num_blocks
+ ), f"Argument '{k}' of make_stage should have the same length as num_blocks={num_blocks}."
newk = k[: -len("_per_block")]
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
curr_kwargs[newk] = v[i]
@@ -1401,7 +1401,7 @@ def num_cell_anchors(self):
def grid_anchors(self, grid_sizes):
anchors = []
- for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors):
+ for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
@@ -1708,10 +1708,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
- assert (
- from_tf
- ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
- pretrained_model_name_or_path + ".index"
+ assert from_tf, (
+ "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint"
+ .format(pretrained_model_name_or_path + ".index")
)
archive_file = pretrained_model_name_or_path + ".index"
else:
@@ -1797,26 +1796,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if len(unexpected_keys) > 0:
print(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
- f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
- f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
- f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
- f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
- f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
print(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
- f"and are newly initialized: {missing_keys}\n"
- f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
print(
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
- f"If your task is similar to the task the model of the checkpoint was trained on, "
- f"you can already use {model.__class__.__name__} for predictions without further training."
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
)
if len(error_msgs) > 0:
raise RuntimeError(
diff --git a/examples/research_projects/lxmert/requirements.txt b/examples/research_projects/lxmert/requirements.txt
index 9b3e500040688f..fc3b85e165411c 100644
--- a/examples/research_projects/lxmert/requirements.txt
+++ b/examples/research_projects/lxmert/requirements.txt
@@ -46,7 +46,7 @@ nbclient==0.5.0
nbconvert==6.0.1
nbformat==5.0.7
nest-asyncio==1.4.0
-notebook==6.4.1
+notebook==6.4.10
numpy==1.21.0
opencv-python==4.4.0.42
packaging==20.3
diff --git a/examples/research_projects/lxmert/utils.py b/examples/research_projects/lxmert/utils.py
index 59ae11d025adf4..8e830fb8359d29 100644
--- a/examples/research_projects/lxmert/utils.py
+++ b/examples/research_projects/lxmert/utils.py
@@ -231,9 +231,10 @@ def compare(in_tensor):
n2 = out_tensor.numpy()[0]
print(n1.shape, n1[0, 0, :5])
print(n2.shape, n2[0, 0, :5])
- assert np.allclose(
- n1, n2, rtol=0.01, atol=0.1
- ), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
+ assert np.allclose(n1, n2, rtol=0.01, atol=0.1), (
+ f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} %"
+ " element-wise mismatch"
+ )
raise Exception("tensors are all good")
# Hugging face functions below
diff --git a/examples/research_projects/mlm_wwm/run_mlm_wwm.py b/examples/research_projects/mlm_wwm/run_mlm_wwm.py
index 51c05ab0b3de60..0afa4135537a85 100644
--- a/examples/research_projects/mlm_wwm/run_mlm_wwm.py
+++ b/examples/research_projects/mlm_wwm/run_mlm_wwm.py
@@ -61,8 +61,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -72,8 +73,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -97,8 +100,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -146,8 +151,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -160,8 +167,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
diff --git a/examples/research_projects/mm-imdb/run_mmimdb.py b/examples/research_projects/mm-imdb/run_mmimdb.py
index c73aec5c874753..9f12257a10a8cb 100644
--- a/examples/research_projects/mm-imdb/run_mmimdb.py
+++ b/examples/research_projects/mm-imdb/run_mmimdb.py
@@ -356,8 +356,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
@@ -423,8 +425,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
diff --git a/examples/research_projects/movement-pruning/bertarize.py b/examples/research_projects/movement-pruning/bertarize.py
index d1e2462a304465..623b46b94386fd 100644
--- a/examples/research_projects/movement-pruning/bertarize.py
+++ b/examples/research_projects/movement-pruning/bertarize.py
@@ -103,15 +103,20 @@ def main(args):
choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
type=str,
required=True,
- help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
+ help=(
+ "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
+ " sigmoied_threshold = Soft movement pruning)"
+ ),
)
parser.add_argument(
"--threshold",
type=float,
required=False,
- help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
- "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
- "Not needed for `l0`",
+ help=(
+ "For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
+ "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
+ "Not needed for `l0`"
+ ),
)
parser.add_argument(
"--model_name_or_path",
diff --git a/examples/research_projects/movement-pruning/counts_parameters.py b/examples/research_projects/movement-pruning/counts_parameters.py
index 0dddfaaa277d76..0aec3766b3f95c 100644
--- a/examples/research_projects/movement-pruning/counts_parameters.py
+++ b/examples/research_projects/movement-pruning/counts_parameters.py
@@ -70,15 +70,20 @@ def main(args):
choices=["l0", "topK", "sigmoied_threshold"],
type=str,
required=True,
- help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
+ help=(
+ "Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement"
+ " pruning)"
+ ),
)
parser.add_argument(
"--threshold",
type=float,
required=False,
- help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
- "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
- "Not needed for `l0`",
+ help=(
+ "For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
+ "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
+ "Not needed for `l0`"
+ ),
)
parser.add_argument(
"--serialization_dir",
diff --git a/examples/research_projects/movement-pruning/emmental/modeling_bert_masked.py b/examples/research_projects/movement-pruning/emmental/modeling_bert_masked.py
index 771d2078d066f8..4228050fe123b3 100644
--- a/examples/research_projects/movement-pruning/emmental/modeling_bert_masked.py
+++ b/examples/research_projects/movement-pruning/emmental/modeling_bert_masked.py
@@ -80,8 +80,8 @@ def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
- "The hidden size (%d) is not a multiple of the number of attention "
- "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
+ % (config.hidden_size, config.num_attention_heads)
)
self.output_attentions = config.output_attentions
diff --git a/examples/research_projects/movement-pruning/masked_run_glue.py b/examples/research_projects/movement-pruning/masked_run_glue.py
index 57f795945b1e8f..e81cf9209c889d 100644
--- a/examples/research_projects/movement-pruning/masked_run_glue.py
+++ b/examples/research_projects/movement-pruning/masked_run_glue.py
@@ -622,8 +622,10 @@ def main():
"--max_seq_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -669,22 +671,29 @@ def main():
"--initial_warmup",
default=1,
type=int,
- help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
- "at its `initial_threshold` value (sparsity schedule).",
+ help=(
+ "Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
+ "at its `initial_threshold` value (sparsity schedule)."
+ ),
)
parser.add_argument(
"--final_warmup",
default=2,
type=int,
- help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
- "at its final_threshold value (sparsity schedule).",
+ help=(
+ "Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
+ "at its final_threshold value (sparsity schedule)."
+ ),
)
parser.add_argument(
"--pruning_method",
default="topK",
type=str,
- help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
+ help=(
+ "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
+ " sigmoied_threshold = Soft movement pruning)."
+ ),
)
parser.add_argument(
"--mask_init",
@@ -717,7 +726,10 @@ def main():
"--teacher_type",
default=None,
type=str,
- help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
+ help=(
+ "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
+ " distillation."
+ ),
)
parser.add_argument(
"--teacher_name_or_path",
@@ -787,8 +799,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -805,7 +819,8 @@ def main():
and not args.overwrite_output_dir
):
raise ValueError(
- f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+ f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to"
+ " overcome."
)
# Setup CUDA, GPU & distributed training
diff --git a/examples/research_projects/movement-pruning/masked_run_squad.py b/examples/research_projects/movement-pruning/masked_run_squad.py
index f1d065f1f46b16..1bd501eda51440 100644
--- a/examples/research_projects/movement-pruning/masked_run_squad.py
+++ b/examples/research_projects/movement-pruning/masked_run_squad.py
@@ -737,8 +737,10 @@ def main():
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument(
"--doc_stride",
@@ -750,8 +752,10 @@ def main():
"--max_query_length",
default=64,
type=int,
- help="The maximum number of tokens for the question. Questions longer than this will "
- "be truncated to this length.",
+ help=(
+ "The maximum number of tokens for the question. Questions longer than this will "
+ "be truncated to this length."
+ ),
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
@@ -785,22 +789,29 @@ def main():
"--initial_warmup",
default=1,
type=int,
- help="Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
- "at its `initial_threshold` value (sparsity schedule).",
+ help=(
+ "Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays"
+ "at its `initial_threshold` value (sparsity schedule)."
+ ),
)
parser.add_argument(
"--final_warmup",
default=2,
type=int,
- help="Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
- "at its final_threshold value (sparsity schedule).",
+ help=(
+ "Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays"
+ "at its final_threshold value (sparsity schedule)."
+ ),
)
parser.add_argument(
"--pruning_method",
default="topK",
type=str,
- help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning).",
+ help=(
+ "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
+ " sigmoied_threshold = Soft movement pruning)."
+ ),
)
parser.add_argument(
"--mask_init",
@@ -833,7 +844,10 @@ def main():
"--teacher_type",
default=None,
type=str,
- help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
+ help=(
+ "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
+ " distillation."
+ ),
)
parser.add_argument(
"--teacher_name_or_path",
@@ -883,20 +897,27 @@ def main():
"--max_answer_length",
default=30,
type=int,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument(
"--verbose_logging",
action="store_true",
- help="If true, all of the warnings related to data processing will be printed. "
- "A number of warnings are expected for a normal SQuAD evaluation.",
+ help=(
+ "If true, all of the warnings related to data processing will be printed. "
+ "A number of warnings are expected for a normal SQuAD evaluation."
+ ),
)
parser.add_argument(
"--lang_id",
default=0,
type=int,
- help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
+ help=(
+ "language id of input for language-specific xlm models (see"
+ " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
+ ),
)
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
@@ -925,8 +946,10 @@ def main():
"--fp16_opt_level",
type=str,
default="O1",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
diff --git a/examples/research_projects/onnx/summarization/bart_onnx/generation_onnx.py b/examples/research_projects/onnx/summarization/bart_onnx/generation_onnx.py
index 58ee49a1b680b3..6db6842968a52a 100644
--- a/examples/research_projects/onnx/summarization/bart_onnx/generation_onnx.py
+++ b/examples/research_projects/onnx/summarization/bart_onnx/generation_onnx.py
@@ -392,13 +392,14 @@ def init(
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
- f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
+ f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
+ " one should make use of `greedy_search` instead."
)
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
raise ValueError(
- f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
- f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
+ "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
+ f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
def hypo_len(self, hypo_idx: int):
@@ -508,7 +509,8 @@ def process(
if beam_idx < self.group_size:
raise ValueError(
- f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
+ f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
+ f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
diff --git a/examples/research_projects/onnx/summarization/run_onnx_exporter.py b/examples/research_projects/onnx/summarization/run_onnx_exporter.py
index 2a62ca9f704dbb..5d751ace8eee10 100644
--- a/examples/research_projects/onnx/summarization/run_onnx_exporter.py
+++ b/examples/research_projects/onnx/summarization/run_onnx_exporter.py
@@ -53,14 +53,16 @@ def parse_args():
"--max_length",
type=int,
default=5,
- help=("The maximum total input sequence length after tokenization."),
+ help="The maximum total input sequence length after tokenization.",
)
parser.add_argument(
"--num_beams",
type=int,
default=None,
- help="Number of beams to use for evaluation. This argument will be "
- "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+ help=(
+ "Number of beams to use for evaluation. This argument will be "
+ "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
+ ),
)
parser.add_argument(
"--model_name_or_path",
diff --git a/examples/research_projects/performer/modeling_flax_performer_utils.py b/examples/research_projects/performer/modeling_flax_performer_utils.py
index abd42ec3d9865e..915e2fa23dd98f 100644
--- a/examples/research_projects/performer/modeling_flax_performer_utils.py
+++ b/examples/research_projects/performer/modeling_flax_performer_utils.py
@@ -535,7 +535,7 @@ def dot_product_attention(
assert key.ndim == value.ndim
for ax in axis:
if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
- raise ValueError("Attention axis must be between the batch " "axis and the last-two axes.")
+ raise ValueError("Attention axis must be between the batch axis and the last-two axes.")
n = key.ndim
# Constructing projection tensor.
diff --git a/examples/research_projects/performer/run_mlm_performer.py b/examples/research_projects/performer/run_mlm_performer.py
index 34aa75f8a9d68c..be20342d3a49c4 100644
--- a/examples/research_projects/performer/run_mlm_performer.py
+++ b/examples/research_projects/performer/run_mlm_performer.py
@@ -98,8 +98,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
performer: bool = field(
@@ -159,8 +160,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated. Default to the max input length of the model."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated. Default to the max input length of the model."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -173,8 +176,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
diff --git a/examples/research_projects/pplm/run_pplm_discrim_train.py b/examples/research_projects/pplm/run_pplm_discrim_train.py
index ec8cd9b9facdf2..6a7351d9e6a63a 100644
--- a/examples/research_projects/pplm/run_pplm_discrim_train.py
+++ b/examples/research_projects/pplm/run_pplm_discrim_train.py
@@ -175,8 +175,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
test_loss /= len(data_loader.dataset)
print(
- "Performance on test set: "
- "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
+ "Performance on test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
)
)
@@ -309,7 +308,7 @@ def train_discriminator(
x.append(seq)
y.append(d["label"])
except Exception:
- print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
+ print("Error evaluating / tokenizing line {}, skipping it".format(i))
pass
full_dataset = Dataset(x, y)
@@ -349,7 +348,7 @@ def train_discriminator(
x.append(seq)
y.append(int(np.sum(d["label"]) > 0))
except Exception:
- print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
+ print("Error evaluating / tokenizing line {}, skipping it".format(i))
pass
full_dataset = Dataset(x, y)
@@ -370,7 +369,7 @@ def train_discriminator(
# class \t text
if dataset_fp is None:
- raise ValueError("When generic dataset is selected, " "dataset_fp needs to be specified aswell.")
+ raise ValueError("When generic dataset is selected, dataset_fp needs to be specified aswell.")
classes = set()
with open(dataset_fp) as f:
@@ -490,15 +489,17 @@ def train_discriminator(
type=str,
default="SST",
choices=("SST", "clickbait", "toxic", "generic"),
- help="dataset to train the discriminator on."
- "In case of generic, the dataset is expected"
- "to be a TSBV file with structure: class \\t text",
+ help=(
+ "dataset to train the discriminator on."
+ "In case of generic, the dataset is expected"
+ "to be a TSBV file with structure: class \\t text"
+ ),
)
parser.add_argument(
"--dataset_fp",
type=str,
default="",
- help="File path of the dataset to use. " "Needed only in case of generic datadset",
+ help="File path of the dataset to use. Needed only in case of generic datadset",
)
parser.add_argument(
"--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
diff --git a/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py b/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py
index 4a618ed77cd536..2a089963039592 100755
--- a/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py
+++ b/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py
@@ -87,8 +87,10 @@
"--max_seq_length",
default=384,
type=int,
- help="The maximum total input sequence length after WordPiece tokenization. Sequences "
- "longer than this will be truncated, and sequences shorter than this will be padded.",
+ help=(
+ "The maximum total input sequence length after WordPiece tokenization. Sequences "
+ "longer than this will be truncated, and sequences shorter than this will be padded."
+ ),
)
parser.add_argument(
"--doc_stride",
@@ -109,8 +111,10 @@
"--max_answer_length",
default=30,
type=int,
- help="The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another.",
+ help=(
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ ),
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
diff --git a/examples/research_projects/quantization-qdqbert/quant_trainer.py b/examples/research_projects/quantization-qdqbert/quant_trainer.py
index b9fbad8a4a82e8..ce1ecb6c51feac 100755
--- a/examples/research_projects/quantization-qdqbert/quant_trainer.py
+++ b/examples/research_projects/quantization-qdqbert/quant_trainer.py
@@ -51,8 +51,10 @@ def add_arguments(parser):
group.add_argument(
"--recalibrate-weights",
action="store_true",
- help="recalibrate weight amaxes by taking the max of the weights."
- " amaxes will be computed with the current quantization granularity (axis).",
+ help=(
+ "recalibrate weight amaxes by taking the max of the weights."
+ " amaxes will be computed with the current quantization granularity (axis)."
+ ),
)
diff --git a/examples/research_projects/quantization-qdqbert/run_quant_qa.py b/examples/research_projects/quantization-qdqbert/run_quant_qa.py
index 36bfb45c8ffca4..97eece4c1d0ac9 100755
--- a/examples/research_projects/quantization-qdqbert/run_quant_qa.py
+++ b/examples/research_projects/quantization-qdqbert/run_quant_qa.py
@@ -83,8 +83,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
do_calib: bool = field(default=False, metadata={"help": "Whether to run calibration of quantization ranges."})
@@ -126,37 +128,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -165,9 +176,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -181,8 +194,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -328,9 +343,9 @@ def main():
# Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# Preprocessing the datasets.
diff --git a/examples/research_projects/rag-end2end-retriever/README.md b/examples/research_projects/rag-end2end-retriever/README.md
index 7cee2f1ea09c84..9bff4e8c29ab0e 100644
--- a/examples/research_projects/rag-end2end-retriever/README.md
+++ b/examples/research_projects/rag-end2end-retriever/README.md
@@ -15,6 +15,10 @@ This code can be modified to experiment with other research on retrival augmente
To start training, use the bash script (finetune_rag_ray_end2end.sh) in this folder. This script also includes descriptions on each command-line argument used.
+# Latest Update
+
+ā ļø Updated the rag-end2end-retriever to be compatible with PL==1.6.4 and RAY==1.13.0 (latest versions to the date 2022-June-11)
+
# Note
ā ļø This project should be run with pytorch-lightning==1.3.1 which has a potential security vulnerability
@@ -22,12 +26,14 @@ To start training, use the bash script (finetune_rag_ray_end2end.sh) in this fol
# Testing
The following two bash scripts can be used to quickly test the implementation.
-1. sh ./test_run/test_rag_new_features.sh
- - Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
- - This is sufficient to check the model's ability to use the set functions correctly.
-2. sh ./test_run/test_finetune.sh script
+1. sh ./test_run/test_finetune.sh script
- Tests the full end-to-end fine-tuning ability with a dummy knowlendge-base and dummy training dataset (check test_dir directory).
- Users can replace the dummy dataset and knowledge-base with their own to do their own finetuning.
+ - Please read the comments in the test_finetune.sh file.
+2. sh ./test_run/test_rag_new_features.sh
+ - Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
+ - This is sufficient to check the model's ability to use the set functions correctly.
+
# Comparison of end2end RAG (including DPR finetuning) VS original-RAG
diff --git a/examples/research_projects/rag-end2end-retriever/callbacks_rag.py b/examples/research_projects/rag-end2end-retriever/callbacks_rag.py
index 55fc9655dff788..5f18244a7aa481 100644
--- a/examples/research_projects/rag-end2end-retriever/callbacks_rag.py
+++ b/examples/research_projects/rag-end2end-retriever/callbacks_rag.py
@@ -31,7 +31,8 @@ def get_checkpoint_callback(output_dir, metric):
exp = "{val_avg_loss:.4f}-{step_count}"
else:
raise NotImplementedError(
- f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
+ f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this"
+ " function."
)
checkpoint_callback = ModelCheckpoint(
@@ -40,7 +41,7 @@ def get_checkpoint_callback(output_dir, metric):
monitor=f"val_{metric}",
mode="max",
save_top_k=1,
- every_n_val_epochs=1, # works only with PL > 1.3
+ every_n_epochs=1, # works only with PL > 1.3
)
return checkpoint_callback
diff --git a/examples/research_projects/rag-end2end-retriever/eval_rag.py b/examples/research_projects/rag-end2end-retriever/eval_rag.py
index 05f78c3d6cdf0e..a8e7abbca6ce29 100644
--- a/examples/research_projects/rag-end2end-retriever/eval_rag.py
+++ b/examples/research_projects/rag-end2end-retriever/eval_rag.py
@@ -146,7 +146,10 @@ def get_args():
"--model_type",
choices=["rag_sequence", "rag_token", "bart"],
type=str,
- help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
+ help=(
+ "RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the"
+ " model_name_or_path"
+ ),
)
parser.add_argument(
"--index_name",
@@ -174,7 +177,10 @@ def get_args():
choices=["e2e", "retrieval"],
default="e2e",
type=str,
- help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
+ help=(
+ "Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates"
+ " precision@k."
+ ),
)
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
parser.add_argument(
@@ -196,9 +202,11 @@ def get_args():
default="qa",
type=str,
choices=["qa", "ans"],
- help="Format of the gold data file"
- "qa - a single line in the following format: question [tab] answer_list"
- "ans - a single line of the gold file contains the expected answer string",
+ help=(
+ "Format of the gold data file"
+ "qa - a single line in the following format: question [tab] answer_list"
+ "ans - a single line of the gold file contains the expected answer string"
+ ),
)
parser.add_argument(
"--predictions_path",
diff --git a/examples/research_projects/rag-end2end-retriever/finetune_rag.py b/examples/research_projects/rag-end2end-retriever/finetune_rag.py
index 96cbc0f7c530aa..1229870e63c696 100644
--- a/examples/research_projects/rag-end2end-retriever/finetune_rag.py
+++ b/examples/research_projects/rag-end2end-retriever/finetune_rag.py
@@ -350,6 +350,7 @@ def training_step(self, batch, batch_idx) -> Dict:
concat.save_to_disk(self.config.passages_path) # here we update the main passage file on the disk
logger.info("done updating the dataset")
+ # To Do (@Aaron) : Useful in the future dynamic memory implementation.
# if you load the index from the disk make sure to update the index file here, otherwise it is ok to update the index file from the worker.
# logger.info("then updating the index")
# shutil.copy(self.custom_config.temp_index, self.config.idex_path)
@@ -360,10 +361,7 @@ def training_step(self, batch, batch_idx) -> Dict:
isEmUpdateBusy = False
isAddIndexBusy = False
-
- self.trainer.accelerator_connector.accelerator.barrier(
- "barrier"
- ) # waint untill the index and kb get re-initialized.
+ self.trainer.strategy.barrier("barrier")
loss_tensors = self._step(batch)
@@ -515,29 +513,37 @@ def add_model_specific_args(parser, root_dir):
"--max_source_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--val_max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--test_max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -555,7 +561,10 @@ def add_model_specific_args(parser, root_dir):
type=int,
default=-1,
required=False,
- help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
+ help=(
+ "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
+ " val_check_interval will effect it."
+ ),
)
parser.add_argument(
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
@@ -564,7 +573,10 @@ def add_model_specific_args(parser, root_dir):
"--model_type",
choices=["rag_sequence", "rag_token", "bart", "t5"],
type=str,
- help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
+ help=(
+ "RAG model type: sequence or token, if none specified, the type is inferred from the"
+ " model_name_or_path"
+ ),
)
parser.add_argument(
"--context_encoder_name",
@@ -590,7 +602,10 @@ def add_model_specific_args(parser, root_dir):
parser.add_argument(
"--gpu_order",
type=str,
- help="order of the GPU used during the fine-tuning. Used to finding free GPUs during the re-encode process. I do not have many GPUs :)",
+ help=(
+ "order of the GPU used during the fine-tuning. Used to finding free GPUs during the re-encode"
+ " process. I do not have many GPUs :)"
+ ),
)
parser.add_argument("--indexing_freq", type=int, help="frequency of re-encode process")
@@ -602,39 +617,53 @@ def add_retriever_specific_args(parser):
"--index_name",
type=str,
default=None,
- help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
+ help=(
+ "Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom'"
+ " for a local index, or 'legacy' for the orignal one)"
+ ),
)
parser.add_argument(
"--passages_path",
type=str,
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset"),
- help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever"
+ " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
parser.add_argument(
"--index_path",
type=str,
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset_hnsw_index.faiss"),
- help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Path to the faiss index for custom index. More info about custom indexes in the RagRetriever"
+ " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
parser.add_argument(
"--distributed_retriever",
choices=["ray", "pytorch"],
type=str,
default="ray",
- help="What implementation to use for distributed retriever? If "
- "pytorch is selected, the index is loaded on training "
- "worker 0, and torch.distributed is used to handle "
- "communication between training worker 0, and the other "
- "training workers. If ray is selected, the Ray library is "
- "used to create load the index on separate processes, "
- "and Ray handles the communication between the training "
- "workers and the retrieval actors.",
+ help=(
+ "What implementation to use for distributed retriever? If "
+ "pytorch is selected, the index is loaded on training "
+ "worker 0, and torch.distributed is used to handle "
+ "communication between training worker 0, and the other "
+ "training workers. If ray is selected, the Ray library is "
+ "used to create load the index on separate processes, "
+ "and Ray handles the communication between the training "
+ "workers and the retrieval actors."
+ ),
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
- help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Whether to use the dummy version of the dataset index. More info about custom indexes in the"
+ " RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
return parser
@@ -645,18 +674,22 @@ def add_ray_specific_args(parser):
"--ray-address",
default="auto",
type=str,
- help="The address of the Ray cluster to connect to. If not "
- "specified, Ray will attempt to automatically detect the "
- "cluster. Has no effect if pytorch is used as the distributed "
- "retriever.",
+ help=(
+ "The address of the Ray cluster to connect to. If not "
+ "specified, Ray will attempt to automatically detect the "
+ "cluster. Has no effect if pytorch is used as the distributed "
+ "retriever."
+ ),
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
- help="The number of retrieval actors to use when Ray is selected"
- "for the distributed retriever. Has no effect when "
- "distributed_retriever is set to pytorch.",
+ help=(
+ "The number of retrieval actors to use when Ray is selected"
+ "for the distributed retriever. Has no effect when "
+ "distributed_retriever is set to pytorch."
+ ),
)
return parser
@@ -686,10 +719,10 @@ def main(args=None, model=None) -> GenerativeQAModule:
named_actors = []
if args.distributed_retriever == "ray" and args.gpus > 1:
if not is_ray_available():
- raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
+ raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
# Connect to an existing Ray cluster.
try:
- ray.init(address=args.ray_address)
+ ray.init(address=args.ray_address, namespace="rag")
except (ConnectionError, ValueError):
logger.warning(
"Connection to Ray cluster failed. Make sure a Ray"
diff --git a/examples/research_projects/rag-end2end-retriever/lightning_base.py b/examples/research_projects/rag-end2end-retriever/lightning_base.py
index 1df0fae5849831..84842944059a7c 100644
--- a/examples/research_projects/rag-end2end-retriever/lightning_base.py
+++ b/examples/research_projects/rag-end2end-retriever/lightning_base.py
@@ -5,7 +5,6 @@
from typing import Any, Dict
import pytorch_lightning as pl
-from pytorch_lightning.plugins.training_type import DDPPlugin
from pytorch_lightning.utilities import rank_zero_info
from transformers import (
@@ -333,8 +332,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
@@ -384,24 +385,22 @@ def generic_train(
train_params = {}
- # TODO: remove with PyTorch 1.6 since pl uses native amp
if args.fp16:
train_params["precision"] = 16
- train_params["amp_level"] = args.fp16_opt_level
if args.gpus > 1:
- train_params["accelerator"] = "ddp"
+ train_params["accelerator"] = "auto"
+ train_params["strategy"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
- # train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
- train_params["profiler"] = None # extra_train_kwargs.get("profiler", None)
+ train_params["profiler"] = None
+ train_params["devices"] = "auto"
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
callbacks=[logging_callback] + extra_callbacks + [InitCallback()] + [checkpoint_callback],
logger=logger,
- plugins=[DDPPlugin(find_unused_parameters=True)], # this is needed in new pytorch-lightning new version
val_check_interval=1,
num_sanity_val_steps=2,
**train_params,
@@ -410,6 +409,6 @@ def generic_train(
if args.do_train:
trainer.fit(model)
- # else:
- # print("RAG modeling tests with new set functions successfuly executed!")
+ else:
+ print("RAG modeling tests with new set functions successfuly executed!")
return trainer
diff --git a/examples/research_projects/rag-end2end-retriever/requirements.txt b/examples/research_projects/rag-end2end-retriever/requirements.txt
index aca89c78e88c0d..32025229d07439 100644
--- a/examples/research_projects/rag-end2end-retriever/requirements.txt
+++ b/examples/research_projects/rag-end2end-retriever/requirements.txt
@@ -1,7 +1,7 @@
-faiss-cpu >= 1.7.0
-datasets >= 1.6.2
-psutil >= 5.7.0
-torch >= 1.4.0
-pytorch-lightning
+faiss-cpu >= 1.7.2
+datasets
+psutil >= 5.9.1
+torch >= 1.11.0
+pytorch-lightning == 1.6.4
nvidia-ml-py3 == 7.352.0
-ray >= 1.3.0
+ray >= 1.13.0
\ No newline at end of file
diff --git a/examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh b/examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh
index bbf69b05380e9c..c44d110d20046a 100755
--- a/examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh
+++ b/examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh
@@ -44,11 +44,14 @@ python finetune_rag.py \
--num_retrieval_workers 4 \
--index_name custom \
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
- --index_gpus 1 \
- --gpu_order [6,7,8,9,0,1,2,3,5,4] \
+ --index_gpus 2 \
+ --gpu_order [2,3,4,5,6,7,8,9,0,1] \
--indexing_freq 5
# Stop the Ray cluster.
ray stop
+
+#CUDA_VISIBLE_DEVICES=2,3,4,5,6,7,8,9,0,1 sh ./test_run/test_finetune.sh
+#Make sure --gpu_order is same.
\ No newline at end of file
diff --git a/examples/research_projects/rag-end2end-retriever/use_own_knowledge_dataset.py b/examples/research_projects/rag-end2end-retriever/use_own_knowledge_dataset.py
index 213aa8d882fc25..432111a2784c37 100644
--- a/examples/research_projects/rag-end2end-retriever/use_own_knowledge_dataset.py
+++ b/examples/research_projects/rag-end2end-retriever/use_own_knowledge_dataset.py
@@ -121,7 +121,10 @@ class RagExampleArguments:
dpr_ctx_encoder_model_name: str = field(
default="facebook/dpr-ctx_encoder-multiset-base",
metadata={
- "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
+ "help": (
+ "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
+ " 'facebook/dpr-ctx_encoder-multiset-base'"
+ )
},
)
output_dir: Optional[str] = field(
@@ -155,7 +158,9 @@ class IndexHnswArguments:
m: int = field(
default=128,
metadata={
- "help": "The number of bi-directional links created for every new element during the HNSW index construction."
+ "help": (
+ "The number of bi-directional links created for every new element during the HNSW index construction."
+ )
},
)
diff --git a/examples/research_projects/rag/callbacks_rag.py b/examples/research_projects/rag/callbacks_rag.py
index a2d87f82247c4a..af1595b08efdf6 100644
--- a/examples/research_projects/rag/callbacks_rag.py
+++ b/examples/research_projects/rag/callbacks_rag.py
@@ -29,7 +29,8 @@ def get_checkpoint_callback(output_dir, metric):
exp = "{val_avg_em:.4f}-{step_count}"
else:
raise NotImplementedError(
- f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
+ f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this"
+ " function."
)
checkpoint_callback = ModelCheckpoint(
diff --git a/examples/research_projects/rag/consolidate_rag_checkpoint.py b/examples/research_projects/rag/consolidate_rag_checkpoint.py
index b9ed7ec0f8115e..39ba7e91f6c3a6 100644
--- a/examples/research_projects/rag/consolidate_rag_checkpoint.py
+++ b/examples/research_projects/rag/consolidate_rag_checkpoint.py
@@ -80,7 +80,10 @@ def consolidate(
parser.add_argument(
"--config_name_or_path",
type=str,
- help="Identifier of the model config to use, if not provided, resolves to a base config for a given ``model_type``",
+ help=(
+ "Identifier of the model config to use, if not provided, resolves to a base config for a given"
+ " ``model_type``"
+ ),
)
args = parser.parse_args()
diff --git a/examples/research_projects/rag/eval_rag.py b/examples/research_projects/rag/eval_rag.py
index 05f78c3d6cdf0e..a8e7abbca6ce29 100644
--- a/examples/research_projects/rag/eval_rag.py
+++ b/examples/research_projects/rag/eval_rag.py
@@ -146,7 +146,10 @@ def get_args():
"--model_type",
choices=["rag_sequence", "rag_token", "bart"],
type=str,
- help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
+ help=(
+ "RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the"
+ " model_name_or_path"
+ ),
)
parser.add_argument(
"--index_name",
@@ -174,7 +177,10 @@ def get_args():
choices=["e2e", "retrieval"],
default="e2e",
type=str,
- help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
+ help=(
+ "Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates"
+ " precision@k."
+ ),
)
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
parser.add_argument(
@@ -196,9 +202,11 @@ def get_args():
default="qa",
type=str,
choices=["qa", "ans"],
- help="Format of the gold data file"
- "qa - a single line in the following format: question [tab] answer_list"
- "ans - a single line of the gold file contains the expected answer string",
+ help=(
+ "Format of the gold data file"
+ "qa - a single line in the following format: question [tab] answer_list"
+ "ans - a single line of the gold file contains the expected answer string"
+ ),
)
parser.add_argument(
"--predictions_path",
diff --git a/examples/research_projects/rag/finetune_rag.py b/examples/research_projects/rag/finetune_rag.py
index 2fd4ef7659c543..f5cef614e2d9f3 100644
--- a/examples/research_projects/rag/finetune_rag.py
+++ b/examples/research_projects/rag/finetune_rag.py
@@ -383,29 +383,37 @@ def add_model_specific_args(parser, root_dir):
"--max_source_length",
default=128,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--val_max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--test_max_target_length",
default=25,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -423,7 +431,10 @@ def add_model_specific_args(parser, root_dir):
type=int,
default=-1,
required=False,
- help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
+ help=(
+ "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
+ " val_check_interval will effect it."
+ ),
)
parser.add_argument(
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
@@ -432,7 +443,10 @@ def add_model_specific_args(parser, root_dir):
"--model_type",
choices=["rag_sequence", "rag_token", "bart", "t5"],
type=str,
- help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
+ help=(
+ "RAG model type: sequence or token, if none specified, the type is inferred from the"
+ " model_name_or_path"
+ ),
)
return parser
@@ -442,39 +456,53 @@ def add_retriever_specific_args(parser):
"--index_name",
type=str,
default=None,
- help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
+ help=(
+ "Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom'"
+ " for a local index, or 'legacy' for the orignal one)"
+ ),
)
parser.add_argument(
"--passages_path",
type=str,
default=None,
- help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever"
+ " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
parser.add_argument(
"--index_path",
type=str,
default=None,
- help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Path to the faiss index for custom index. More info about custom indexes in the RagRetriever"
+ " documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
parser.add_argument(
"--distributed_retriever",
choices=["ray", "pytorch"],
type=str,
default="pytorch",
- help="What implementation to use for distributed retriever? If "
- "pytorch is selected, the index is loaded on training "
- "worker 0, and torch.distributed is used to handle "
- "communication between training worker 0, and the other "
- "training workers. If ray is selected, the Ray library is "
- "used to create load the index on separate processes, "
- "and Ray handles the communication between the training "
- "workers and the retrieval actors.",
+ help=(
+ "What implementation to use for distributed retriever? If "
+ "pytorch is selected, the index is loaded on training "
+ "worker 0, and torch.distributed is used to handle "
+ "communication between training worker 0, and the other "
+ "training workers. If ray is selected, the Ray library is "
+ "used to create load the index on separate processes, "
+ "and Ray handles the communication between the training "
+ "workers and the retrieval actors."
+ ),
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
- help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+ help=(
+ "Whether to use the dummy version of the dataset index. More info about custom indexes in the"
+ " RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`"
+ ),
)
return parser
@@ -485,18 +513,22 @@ def add_ray_specific_args(parser):
"--ray-address",
default="auto",
type=str,
- help="The address of the Ray cluster to connect to. If not "
- "specified, Ray will attempt to automatically detect the "
- "cluster. Has no effect if pytorch is used as the distributed "
- "retriever.",
+ help=(
+ "The address of the Ray cluster to connect to. If not "
+ "specified, Ray will attempt to automatically detect the "
+ "cluster. Has no effect if pytorch is used as the distributed "
+ "retriever."
+ ),
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
- help="The number of retrieval actors to use when Ray is selected"
- "for the distributed retriever. Has no effect when "
- "distributed_retriever is set to pytorch.",
+ help=(
+ "The number of retrieval actors to use when Ray is selected"
+ "for the distributed retriever. Has no effect when "
+ "distributed_retriever is set to pytorch."
+ ),
)
return parser
@@ -514,7 +546,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
named_actors = []
if args.distributed_retriever == "ray" and args.gpus > 1:
if not is_ray_available():
- raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
+ raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
# Connect to an existing Ray cluster.
try:
ray.init(address=args.ray_address, namespace="rag")
diff --git a/examples/research_projects/rag/lightning_base.py b/examples/research_projects/rag/lightning_base.py
index 1e0f67627e7c34..77830a4760ad39 100644
--- a/examples/research_projects/rag/lightning_base.py
+++ b/examples/research_projects/rag/lightning_base.py
@@ -321,8 +321,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
diff --git a/examples/research_projects/rag/use_own_knowledge_dataset.py b/examples/research_projects/rag/use_own_knowledge_dataset.py
index 269765caab8653..dc08f508228abc 100644
--- a/examples/research_projects/rag/use_own_knowledge_dataset.py
+++ b/examples/research_projects/rag/use_own_knowledge_dataset.py
@@ -154,7 +154,10 @@ class RagExampleArguments:
dpr_ctx_encoder_model_name: str = field(
default="facebook/dpr-ctx_encoder-multiset-base",
metadata={
- "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
+ "help": (
+ "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
+ " 'facebook/dpr-ctx_encoder-multiset-base'"
+ )
},
)
output_dir: Optional[str] = field(
@@ -188,7 +191,9 @@ class IndexHnswArguments:
m: int = field(
default=128,
metadata={
- "help": "The number of bi-directional links created for every new element during the HNSW index construction."
+ "help": (
+ "The number of bi-directional links created for every new element during the HNSW index construction."
+ )
},
)
diff --git a/examples/research_projects/robust-speech-event/eval.py b/examples/research_projects/robust-speech-event/eval.py
index 53cd244daf7549..32e3d1f2c729f8 100755
--- a/examples/research_projects/robust-speech-event/eval.py
+++ b/examples/research_projects/robust-speech-event/eval.py
@@ -24,7 +24,7 @@ def log_results(result: Dataset, args: Dict[str, str]):
cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
# print & log results
- result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
+ result_str = f"WER: {wer_result}\nCER: {cer_result}"
print(result_str)
with open(f"{dataset_id}_eval_results.txt", "w") as f:
diff --git a/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py b/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py
index 2317367e7cc3c3..521036c78e4ba5 100755
--- a/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py
+++ b/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_bnb.py
@@ -103,9 +103,11 @@ class ModelArguments:
mask_time_prob: float = field(
default=0.05,
metadata={
- "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis."
+ )
},
)
mask_time_length: int = field(
@@ -115,8 +117,11 @@ class ModelArguments:
mask_feature_prob: float = field(
default=0.0,
metadata={
- "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
+ " bins will be masked along the time axis."
+ )
},
)
mask_feature_length: int = field(
@@ -175,15 +180,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
chars_to_ignore: Optional[List[str]] = list_field(
@@ -197,7 +206,10 @@ class DataTrainingArguments:
max_duration_in_seconds: float = field(
default=20.0,
metadata={
- "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
+ "help": (
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to"
+ " 'max_duration_in_seconds`"
+ )
},
)
min_duration_in_seconds: float = field(
@@ -206,17 +218,21 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "If :obj:`True`, will use the token generated when running"
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ "help": (
+ "If :obj:`True`, will use the token generated when running"
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ )
},
)
unk_token: str = field(
@@ -234,10 +250,12 @@ class DataTrainingArguments:
phoneme_language: Optional[str] = field(
default=None,
metadata={
- "help": "The target language that should be used be"
- " passed to the tokenizer for tokenization. Note that"
- " this is only relevant if the model classifies the"
- " input audio to a sequence of phoneme sequences."
+ "help": (
+ "The target language that should be used be"
+ " passed to the tokenizer for tokenization. Note that"
+ " this is only relevant if the model classifies the"
+ " input audio to a sequence of phoneme sequences."
+ )
},
)
@@ -406,9 +424,9 @@ def main():
if data_args.audio_column_name not in raw_datasets["train"].column_names:
raise ValueError(
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
- f"{', '.join(raw_datasets['train'].column_names)}."
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
+ f" {', '.join(raw_datasets['train'].column_names)}."
)
if data_args.text_column_name not in raw_datasets["train"].column_names:
@@ -743,7 +761,10 @@ def compute_metrics(pred):
"finetuned_from": model_args.model_name_or_path,
"tasks": "speech-recognition",
"tags": ["automatic-speech-recognition", data_args.dataset_name],
- "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}"
+ ),
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
}
if "common_voice" in data_args.dataset_name:
diff --git a/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py b/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py
index 9e69178088f608..d357bc469649ea 100644
--- a/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py
+++ b/examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py
@@ -102,9 +102,11 @@ class ModelArguments:
mask_time_prob: float = field(
default=0.05,
metadata={
- "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis."
+ )
},
)
mask_time_length: int = field(
@@ -114,8 +116,11 @@ class ModelArguments:
mask_feature_prob: float = field(
default=0.0,
metadata={
- "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
+ " bins will be masked along the time axis."
+ )
},
)
mask_feature_length: int = field(
@@ -147,8 +152,10 @@ class DataTrainingArguments:
train_split_name: str = field(
default="train+validation",
metadata={
- "help": "The name of the training data set split to use (via the datasets library). Defaults to "
- "'train+validation'"
+ "help": (
+ "The name of the training data set split to use (via the datasets library). Defaults to "
+ "'train+validation'"
+ )
},
)
eval_split_name: str = field(
@@ -175,22 +182,28 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
shuffle_buffer_size: Optional[int] = field(
default=500,
metadata={
- "help": "The number of streamed examples to download before shuffling them. The large the buffer, "
- "the closer it is to real offline shuffling."
+ "help": (
+ "The number of streamed examples to download before shuffling them. The large the buffer, "
+ "the closer it is to real offline shuffling."
+ )
},
)
chars_to_ignore: Optional[List[str]] = list_field(
@@ -208,26 +221,32 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "If :obj:`True`, will use the token generated when running"
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ "help": (
+ "If :obj:`True`, will use the token generated when running"
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ )
},
)
phoneme_language: Optional[str] = field(
default=None,
metadata={
- "help": "The target language that should be used be"
- " passed to the tokenizer for tokenization. Note that"
- " this is only relevant if the model classifies the"
- " input audio to a sequence of phoneme sequences."
+ "help": (
+ "The target language that should be used be"
+ " passed to the tokenizer for tokenization. Note that"
+ " this is only relevant if the model classifies the"
+ " input audio to a sequence of phoneme sequences."
+ )
},
)
@@ -393,9 +412,9 @@ def load_streaming_dataset(split, sampling_rate, **kwargs):
if data_args.audio_column_name not in raw_column_names["train"]:
raise ValueError(
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
- f"{', '.join(raw_column_names['train'])}."
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
+ f" {', '.join(raw_column_names['train'])}."
)
if data_args.text_column_name not in raw_column_names["train"]:
@@ -641,7 +660,10 @@ def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
"finetuned_from": model_args.model_name_or_path,
"tasks": "speech-recognition",
"tags": ["automatic-speech-recognition", data_args.dataset_name],
- "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}"
+ ),
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
}
if "common_voice" in data_args.dataset_name:
diff --git a/examples/research_projects/self-training-text-classification/finetuning.py b/examples/research_projects/self-training-text-classification/finetuning.py
index 8ad92359b619e0..eeb0a285dff987 100644
--- a/examples/research_projects/self-training-text-classification/finetuning.py
+++ b/examples/research_projects/self-training-text-classification/finetuning.py
@@ -100,15 +100,19 @@ class FTDataArguments:
max_length: Optional[int] = dataclasses.field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: Optional[bool] = dataclasses.field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
@@ -147,7 +151,10 @@ class FTTrainingArguments:
weight_decay: Optional[float] = dataclasses.field(
default=0.0,
metadata={
- "help": "The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`] optimizer."
+ "help": (
+ "The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in"
+ " [`AdamW`] optimizer."
+ )
},
)
learning_rate: Optional[float] = dataclasses.field(
@@ -157,13 +164,18 @@ class FTTrainingArguments:
gradient_accumulation_steps: Optional[int] = dataclasses.field(
default=1,
metadata={
- "help": "Number of updates steps to accumulate the gradients for, before performing a backward/update pass."
+ "help": (
+ "Number of updates steps to accumulate the gradients for, before performing a backward/update pass."
+ )
},
)
max_steps: Optional[int] = dataclasses.field(
default=-1,
metadata={
- "help": "If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`."
+ "help": (
+ "If set to a positive number, the total number of training steps to perform. Overrides"
+ " `num_train_epochs`."
+ )
},
)
lr_scheduler_type: Optional[str] = dataclasses.field(
@@ -172,7 +184,10 @@ class FTTrainingArguments:
warmup_steps: Optional[int] = dataclasses.field(
default=1,
metadata={
- "help": "Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`."
+ "help": (
+ "Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of"
+ " `warmup_ratio`."
+ )
},
)
evaluation_strategy: Optional[str] = dataclasses.field(
diff --git a/examples/research_projects/seq2seq-distillation/callbacks.py b/examples/research_projects/seq2seq-distillation/callbacks.py
index 388b6d53ddd347..6f6ed5dd58acfd 100644
--- a/examples/research_projects/seq2seq-distillation/callbacks.py
+++ b/examples/research_projects/seq2seq-distillation/callbacks.py
@@ -93,7 +93,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
exp = "{val_avg_loss:.4f}-{step_count}"
else:
raise NotImplementedError(
- f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
+ f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to"
+ " this function."
)
checkpoint_callback = ModelCheckpoint(
diff --git a/examples/research_projects/seq2seq-distillation/distillation.py b/examples/research_projects/seq2seq-distillation/distillation.py
index 1f9106f0c0a76b..5a403be8d56212 100755
--- a/examples/research_projects/seq2seq-distillation/distillation.py
+++ b/examples/research_projects/seq2seq-distillation/distillation.py
@@ -52,9 +52,10 @@ def __init__(self, hparams):
student.config.length_penalty = hparams.length_penalty
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
super().__init__(hparams, model=student, config=student.config)
- assert (
- student.config.model_type == teacher.config.model_type
- ), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"
+ assert student.config.model_type == teacher.config.model_type, (
+ f"teacher, student model types should be the same, got {student.config.model_type} !="
+ f" {teacher.config.model_type}"
+ )
if student.config.model_type == "t5":
student_encoder_layers = len(student.get_encoder().block)
diff --git a/examples/research_projects/seq2seq-distillation/finetune.py b/examples/research_projects/seq2seq-distillation/finetune.py
index 5874509377aa73..c20b361d583631 100755
--- a/examples/research_projects/seq2seq-distillation/finetune.py
+++ b/examples/research_projects/seq2seq-distillation/finetune.py
@@ -303,29 +303,37 @@ def add_model_specific_args(parser, root_dir):
"--max_source_length",
default=1024,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--max_target_length",
default=56,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--val_max_target_length",
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument(
"--test_max_target_length",
default=142,
type=int,
- help="The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded.",
+ help=(
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ ),
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
@@ -353,7 +361,10 @@ def add_model_specific_args(parser, root_dir):
type=int,
default=-1,
required=False,
- help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
+ help=(
+ "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So"
+ " val_check_interval will effect it."
+ ),
)
return parser
diff --git a/examples/research_projects/seq2seq-distillation/lightning_base.py b/examples/research_projects/seq2seq-distillation/lightning_base.py
index b7f53076e3bc31..b3104a25a8b129 100644
--- a/examples/research_projects/seq2seq-distillation/lightning_base.py
+++ b/examples/research_projects/seq2seq-distillation/lightning_base.py
@@ -312,8 +312,10 @@ def add_generic_args(parser, root_dir) -> None:
"--fp16_opt_level",
type=str,
default="O2",
- help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
- "See details at https://nvidia.github.io/apex/amp.html",
+ help=(
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+ "See details at https://nvidia.github.io/apex/amp.html"
+ ),
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
diff --git a/examples/research_projects/seq2seq-distillation/make_student.py b/examples/research_projects/seq2seq-distillation/make_student.py
index 8d70292d0e5a09..a4021505b998e0 100644
--- a/examples/research_projects/seq2seq-distillation/make_student.py
+++ b/examples/research_projects/seq2seq-distillation/make_student.py
@@ -58,7 +58,8 @@ def pick_layers_to_copy(n_student, n_teacher):
except KeyError:
if n_student != n_teacher:
warnings.warn(
- f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
+ f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first"
+ f" {n_student}"
)
return list(range(n_student))
@@ -144,7 +145,8 @@ def create_student_by_copying_alternating_layers(
if copy_first_teacher_layers: # Our copying is done. We just log and save
e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
logger.info(
- f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
+ f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to"
+ f" {save_path}"
)
student.save_pretrained(save_path)
return student, e_layers_to_copy, d_layers_to_copy
diff --git a/examples/research_projects/seq2seq-distillation/run_eval.py b/examples/research_projects/seq2seq-distillation/run_eval.py
index de752c7df189e5..3f685884e8e893 100755
--- a/examples/research_projects/seq2seq-distillation/run_eval.py
+++ b/examples/research_projects/seq2seq-distillation/run_eval.py
@@ -108,7 +108,10 @@ def run_generate(verbose=True):
nargs="?",
type=str,
const=datetime_now(),
- help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
+ help=(
+ "use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g."
+ " lang=en-ru. If no value is passed, the current datetime string will be used."
+ ),
)
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
diff --git a/examples/research_projects/tapex/run_tabfact_with_tapex.py b/examples/research_projects/tapex/run_tabfact_with_tapex.py
index 0ed573ad9c1adb..19c21c33948edb 100644
--- a/examples/research_projects/tapex/run_tabfact_with_tapex.py
+++ b/examples/research_projects/tapex/run_tabfact_with_tapex.py
@@ -77,8 +77,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -87,29 +89,37 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
train_file: Optional[str] = field(
@@ -164,8 +174,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
diff --git a/examples/research_projects/tapex/run_wikisql_with_tapex.py b/examples/research_projects/tapex/run_wikisql_with_tapex.py
index 594c83cb6be53a..461bfbec9ae3c7 100644
--- a/examples/research_projects/tapex/run_wikisql_with_tapex.py
+++ b/examples/research_projects/tapex/run_wikisql_with_tapex.py
@@ -82,8 +82,10 @@ class ModelArguments:
tokenizer_name: Optional[str] = field(
default=None,
metadata={
- "help": "Pretrained tokenizer name or path if not the same as model_name. "
- "By default we use BART-large tokenizer for TAPEX-large."
+ "help": (
+ "Pretrained tokenizer name or path if not the same as model_name. "
+ "By default we use BART-large tokenizer for TAPEX-large."
+ )
},
)
cache_dir: Optional[str] = field(
@@ -101,8 +103,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -125,14 +129,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -145,60 +150,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
diff --git a/examples/research_projects/tapex/run_wikitablequestions_with_tapex.py b/examples/research_projects/tapex/run_wikitablequestions_with_tapex.py
index 4398309566a8f4..1750adc546f017 100644
--- a/examples/research_projects/tapex/run_wikitablequestions_with_tapex.py
+++ b/examples/research_projects/tapex/run_wikitablequestions_with_tapex.py
@@ -80,8 +80,10 @@ class ModelArguments:
tokenizer_name: Optional[str] = field(
default=None,
metadata={
- "help": "Pretrained tokenizer name or path if not the same as model_name. "
- "By default we use BART-large tokenizer for TAPEX-large."
+ "help": (
+ "Pretrained tokenizer name or path if not the same as model_name. "
+ "By default we use BART-large tokenizer for TAPEX-large."
+ )
},
)
cache_dir: Optional[str] = field(
@@ -99,8 +101,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -123,14 +127,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -143,60 +148,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
diff --git a/examples/research_projects/visual_bert/demo.ipynb b/examples/research_projects/visual_bert/demo.ipynb
index a025e419a3c67b..14a65ce3df3396 100644
--- a/examples/research_projects/visual_bert/demo.ipynb
+++ b/examples/research_projects/visual_bert/demo.ipynb
@@ -4,7 +4,7 @@
"cell_type": "code",
"execution_count": 1,
"source": [
- "#%pip install-r requirements.txt"
+ "# %pip install-r requirements.txt"
],
"outputs": [],
"metadata": {}
diff --git a/examples/research_projects/visual_bert/modeling_frcnn.py b/examples/research_projects/visual_bert/modeling_frcnn.py
index 39a0c6aea8787d..33c1133e9589f4 100644
--- a/examples/research_projects/visual_bert/modeling_frcnn.py
+++ b/examples/research_projects/visual_bert/modeling_frcnn.py
@@ -592,7 +592,7 @@ def __call__(self, match_quality_matrix):
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
- for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
+ for l, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
low_high = (matched_vals >= low) & (matched_vals < high)
match_labels[low_high] = l
@@ -1037,9 +1037,9 @@ def make_stage(
curr_kwargs = {}
for k, v in kwargs.items():
if k.endswith("_per_block"):
- assert len(v) == num_blocks, (
- f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
- )
+ assert (
+ len(v) == num_blocks
+ ), f"Argument '{k}' of make_stage should have the same length as num_blocks={num_blocks}."
newk = k[: -len("_per_block")]
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
curr_kwargs[newk] = v[i]
@@ -1401,7 +1401,7 @@ def num_cell_anchors(self):
def grid_anchors(self, grid_sizes):
anchors = []
- for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors):
+ for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
@@ -1708,10 +1708,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
- assert (
- from_tf
- ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
- pretrained_model_name_or_path + ".index"
+ assert from_tf, (
+ "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint"
+ .format(pretrained_model_name_or_path + ".index")
)
archive_file = pretrained_model_name_or_path + ".index"
else:
@@ -1797,26 +1796,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if len(unexpected_keys) > 0:
print(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
- f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
- f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
- f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
- f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
- f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
print(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
- f"and are newly initialized: {missing_keys}\n"
- f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
print(
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
- f"If your task is similar to the task the model of the checkpoint was trained on, "
- f"you can already use {model.__class__.__name__} for predictions without further training."
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
)
if len(error_msgs) > 0:
raise RuntimeError(
diff --git a/examples/research_projects/visual_bert/requirements.txt b/examples/research_projects/visual_bert/requirements.txt
index 9b3e500040688f..fc3b85e165411c 100644
--- a/examples/research_projects/visual_bert/requirements.txt
+++ b/examples/research_projects/visual_bert/requirements.txt
@@ -46,7 +46,7 @@ nbclient==0.5.0
nbconvert==6.0.1
nbformat==5.0.7
nest-asyncio==1.4.0
-notebook==6.4.1
+notebook==6.4.10
numpy==1.21.0
opencv-python==4.4.0.42
packaging==20.3
diff --git a/examples/research_projects/visual_bert/utils.py b/examples/research_projects/visual_bert/utils.py
index 59ae11d025adf4..8e830fb8359d29 100644
--- a/examples/research_projects/visual_bert/utils.py
+++ b/examples/research_projects/visual_bert/utils.py
@@ -231,9 +231,10 @@ def compare(in_tensor):
n2 = out_tensor.numpy()[0]
print(n1.shape, n1[0, 0, :5])
print(n2.shape, n2[0, 0, :5])
- assert np.allclose(
- n1, n2, rtol=0.01, atol=0.1
- ), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
+ assert np.allclose(n1, n2, rtol=0.01, atol=0.1), (
+ f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} %"
+ " element-wise mismatch"
+ )
raise Exception("tensors are all good")
# Hugging face functions below
diff --git a/examples/research_projects/wav2vec2/run_asr.py b/examples/research_projects/wav2vec2/run_asr.py
index 9b031cca1972e1..bb34e0a0c71a83 100755
--- a/examples/research_projects/wav2vec2/run_asr.py
+++ b/examples/research_projects/wav2vec2/run_asr.py
@@ -99,7 +99,9 @@ class DataTrainingArguments:
validation_split_name: Optional[str] = field(
default="validation",
metadata={
- "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ "help": (
+ "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
target_text_column: Optional[str] = field(
@@ -121,7 +123,10 @@ class DataTrainingArguments:
orthography: Optional[str] = field(
default="librispeech",
metadata={
- "help": "Orthography used for normalization and tokenization: 'librispeech' (default), 'timit', or 'buckwalter'."
+ "help": (
+ "Orthography used for normalization and tokenization: 'librispeech' (default), 'timit', or"
+ " 'buckwalter'."
+ )
},
)
overwrite_cache: bool = field(
@@ -392,11 +397,13 @@ def filter_by_max_duration(example):
val_dataset = val_dataset.filter(filter_by_max_duration, remove_columns=["duration_in_seconds"])
if len(train_dataset) > old_train_size:
logger.warning(
- f"Filtered out {len(train_dataset) - old_train_size} train example(s) longer than {data_args.max_duration_in_seconds} second(s)."
+ f"Filtered out {len(train_dataset) - old_train_size} train example(s) longer than"
+ f" {data_args.max_duration_in_seconds} second(s)."
)
if len(val_dataset) > old_val_size:
logger.warning(
- f"Filtered out {len(val_dataset) - old_val_size} validation example(s) longer than {data_args.max_duration_in_seconds} second(s)."
+ f"Filtered out {len(val_dataset) - old_val_size} validation example(s) longer than"
+ f" {data_args.max_duration_in_seconds} second(s)."
)
logger.info(f"Split sizes: {len(train_dataset)} train and {len(val_dataset)} validation.")
diff --git a/examples/research_projects/wav2vec2/run_common_voice.py b/examples/research_projects/wav2vec2/run_common_voice.py
index 5825c1feb10bb2..b8480d3c7d1c90 100644
--- a/examples/research_projects/wav2vec2/run_common_voice.py
+++ b/examples/research_projects/wav2vec2/run_common_voice.py
@@ -79,9 +79,11 @@ class ModelArguments:
mask_time_prob: Optional[float] = field(
default=0.05,
metadata={
- "help": "Propability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
+ "help": (
+ "Propability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
+ )
},
)
layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
@@ -116,15 +118,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
chars_to_ignore: List[str] = list_field(
diff --git a/examples/research_projects/wav2vec2/run_pretrain.py b/examples/research_projects/wav2vec2/run_pretrain.py
index 248f32443f0488..fb430d14074836 100755
--- a/examples/research_projects/wav2vec2/run_pretrain.py
+++ b/examples/research_projects/wav2vec2/run_pretrain.py
@@ -104,7 +104,9 @@ class DataTrainingArguments:
validation_split_name: Optional[str] = field(
default="validation",
metadata={
- "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ "help": (
+ "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
speech_file_column: Optional[str] = field(
@@ -200,7 +202,6 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
(batch_size, mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
- device=batch["input_values"].device,
attention_mask=attention_mask,
min_masks=2,
)
@@ -369,7 +370,8 @@ def normalize(batch):
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
- "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
+ "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
+ " ``config.feat_extract_norm='layer'"
)
model = Wav2Vec2ForPreTraining(config)
diff --git a/examples/research_projects/xtreme-s/run_xtreme_s.py b/examples/research_projects/xtreme-s/run_xtreme_s.py
index a186d4b7cee77d..972c6d5462ff8e 100644
--- a/examples/research_projects/xtreme-s/run_xtreme_s.py
+++ b/examples/research_projects/xtreme-s/run_xtreme_s.py
@@ -89,7 +89,7 @@ class ModelArguments:
cache_dir: Optional[str] = field(
default=None,
metadata={
- "help": "Where do you want to store the pretrained models and datasets downloaded from " "huggingface.co"
+ "help": "Where do you want to store the pretrained models and datasets downloaded from huggingface.co"
},
)
freeze_feature_encoder: bool = field(
@@ -115,9 +115,11 @@ class ModelArguments:
mask_time_prob: float = field(
default=0.05,
metadata={
- "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
- "vectors will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
+ "vectors will be masked along the time axis."
+ )
},
)
mask_time_length: int = field(
@@ -127,8 +129,11 @@ class ModelArguments:
mask_feature_prob: float = field(
default=0.0,
metadata={
- "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
- "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+ "help": (
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
+ " bins will be masked along the time axis."
+ )
},
)
mask_feature_length: int = field(
@@ -162,8 +167,10 @@ class DataTrainingArguments:
task: str = field(
default=None,
metadata={
- "help": "The task name of the benchmark to use (via the datasets library). Should be on of: "
- "'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
+ "help": (
+ "The task name of the benchmark to use (via the datasets library). Should be on of: "
+ "'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
+ )
},
)
language: str = field(
@@ -173,10 +180,12 @@ class DataTrainingArguments:
language_group: str = field(
default=None,
metadata={
- "help": "The language group to select a subset of languages to train on. "
- "This option is only used the 'fleurs-asr' task. Should be one of: "
- "'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', "
- "'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'."
+ "help": (
+ "The language group to select a subset of languages to train on. "
+ "This option is only used the 'fleurs-asr' task. Should be one of: "
+ "'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', "
+ "'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'."
+ )
},
)
train_split_name: str = field(
@@ -188,14 +197,15 @@ class DataTrainingArguments:
eval_split_name: str = field(
default="validation",
metadata={
- "help": "The name of the evaluation dataset split to use (via the datasets library). "
- "Defaults to 'validation'"
+ "help": (
+ "The name of the evaluation dataset split to use (via the datasets library). Defaults to 'validation'"
+ )
},
)
predict_split_name: str = field(
default="test",
metadata={
- "help": "The name of the prediction dataset split to use (via the datasets library). " "Defaults to 'test'"
+ "help": "The name of the prediction dataset split to use (via the datasets library). Defaults to 'test'"
},
)
audio_column_name: str = field(
@@ -205,8 +215,10 @@ class DataTrainingArguments:
target_column_name: str = field(
default=None,
metadata={
- "help": "The name of the dataset column containing the target data "
- "(transcription/translation/label). If None, the name will be inferred from the task. Defaults to None."
+ "help": (
+ "The name of the dataset column containing the target data (transcription/translation/label). If None,"
+ " the name will be inferred from the task. Defaults to None."
+ )
},
)
overwrite_cache: bool = field(
@@ -219,22 +231,28 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
chars_to_ignore: Optional[List[str]] = list_field(
@@ -244,7 +262,10 @@ class DataTrainingArguments:
max_duration_in_seconds: float = field(
default=30.0,
metadata={
- "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
+ "help": (
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to"
+ " 'max_duration_in_seconds`"
+ )
},
)
min_duration_in_seconds: float = field(
@@ -253,17 +274,21 @@ class DataTrainingArguments:
preprocessing_only: bool = field(
default=False,
metadata={
- "help": "Whether to only do data preprocessing and skip training. "
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
- "so that the cached datasets can consequently be loaded in distributed training"
+ "help": (
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
+ " can consequently be loaded in distributed training"
+ )
},
)
use_auth_token: bool = field(
default=False,
metadata={
- "help": "If :obj:`True`, will use the token generated when running"
- ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ "help": (
+ "If :obj:`True`, will use the token generated when running"
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+ )
},
)
unk_token: str = field(
@@ -281,17 +306,21 @@ class DataTrainingArguments:
phoneme_language: Optional[str] = field(
default=None,
metadata={
- "help": "The target language that should be used be"
- " passed to the tokenizer for tokenization. Note that"
- " this is only relevant if the model classifies the"
- " input audio to a sequence of phoneme sequences."
+ "help": (
+ "The target language that should be used be"
+ " passed to the tokenizer for tokenization. Note that"
+ " this is only relevant if the model classifies the"
+ " input audio to a sequence of phoneme sequences."
+ )
},
)
per_lang_metrics: bool = field(
default=True,
metadata={
- "help": "If `True`, compute the test metrics separately for each language, and average the results. "
- "If `False` compute the average test metrics in a single pass for all languages at once."
+ "help": (
+ "If `True`, compute the test metrics separately for each language, and average the results. "
+ "If `False` compute the average test metrics in a single pass for all languages at once."
+ )
},
)
@@ -446,7 +475,7 @@ def main():
if task_name is None:
raise ValueError(
- "Set --task should be set to '' " "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
+ "Set --task should be set to '' (e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
)
if lang_id is None:
raise ValueError(
@@ -481,9 +510,9 @@ def main():
if data_args.audio_column_name not in raw_datasets["train"].column_names:
raise ValueError(
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
- f"{', '.join(raw_datasets['train'].column_names)}."
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
+ f" {', '.join(raw_datasets['train'].column_names)}."
)
if target_column_name not in raw_datasets["train"].column_names:
@@ -903,7 +932,10 @@ def compute_classification_metric(pred):
"finetuned_from": model_args.model_name_or_path,
"tasks": task_name,
"tags": [task_name, data_args.dataset_name],
- "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
+ "dataset_args": (
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
+ f" {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}"
+ ),
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
"language": data_args.language,
}
diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py
index 3598ad668a96cc..46c8d339d970c3 100755
--- a/examples/tensorflow/language-modeling/run_clm.py
+++ b/examples/tensorflow/language-modeling/run_clm.py
@@ -53,6 +53,7 @@
create_optimizer,
set_seed,
)
+from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
@@ -73,8 +74,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -84,8 +86,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -109,8 +113,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -150,9 +156,11 @@ class DataTrainingArguments:
block_size: Optional[int] = field(
default=None,
metadata={
- "help": "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -166,15 +174,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
keep_linebreaks: bool = field(
@@ -221,6 +233,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm", model_args, data_args, framework="tensorflow")
+
# Sanity checks
if data_args.dataset_name is None and data_args.train_file is None and data_args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
@@ -412,7 +428,8 @@ def group_texts(examples):
eval_dataset = lm_datasets["validation"]
else:
logger.info(
- f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation as provided in data_args"
+ f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation"
+ " as provided in data_args"
)
train_indices, val_indices = train_test_split(
list(range(len(train_dataset))), test_size=data_args.validation_split_percentage / 100
diff --git a/examples/tensorflow/language-modeling/run_mlm.py b/examples/tensorflow/language-modeling/run_mlm.py
index 8b32070b2dd1e0..46b27dab662519 100755
--- a/examples/tensorflow/language-modeling/run_mlm.py
+++ b/examples/tensorflow/language-modeling/run_mlm.py
@@ -55,6 +55,7 @@
create_optimizer,
set_seed,
)
+from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
@@ -74,8 +75,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "The model checkpoint for weights initialization."
- "Don't set if you want to train a model from scratch."
+ "help": (
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
+ )
},
)
model_type: Optional[str] = field(
@@ -85,8 +87,10 @@ class ModelArguments:
config_overrides: Optional[str] = field(
default=None,
metadata={
- "help": "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
},
)
config_name: Optional[str] = field(
@@ -110,8 +114,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -151,8 +157,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated."
+ )
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -169,22 +177,28 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -229,6 +243,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_mlm", model_args, data_args, framework="tensorflow")
+
# Sanity checks
if data_args.dataset_name is None and data_args.train_file is None and data_args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
@@ -456,7 +474,8 @@ def group_texts(examples):
eval_dataset = tokenized_datasets["validation"]
else:
logger.info(
- f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation as provided in data_args"
+ f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation"
+ " as provided in data_args"
)
train_indices, val_indices = train_test_split(
list(range(len(train_dataset))), test_size=data_args.validation_split_percentage / 100
diff --git a/examples/tensorflow/multiple-choice/run_swag.py b/examples/tensorflow/multiple-choice/run_swag.py
index a1f39eeeb01143..1c88f0db51b0df 100644
--- a/examples/tensorflow/multiple-choice/run_swag.py
+++ b/examples/tensorflow/multiple-choice/run_swag.py
@@ -44,11 +44,11 @@
set_seed,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.utils import PaddingStrategy, check_min_version
+from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
logger = logging.getLogger(__name__)
@@ -156,8 +156,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -183,30 +185,38 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total input sequence length after tokenization. If passed, sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. If passed, sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to the maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to the maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
@@ -236,6 +246,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_swag", model_args, data_args, framework="tensorflow")
+
output_dir = Path(training_args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# endregion
diff --git a/examples/tensorflow/question-answering/run_qa.py b/examples/tensorflow/question-answering/run_qa.py
index 877fe8800999ad..55465f345adfc9 100755
--- a/examples/tensorflow/question-answering/run_qa.py
+++ b/examples/tensorflow/question-answering/run_qa.py
@@ -41,12 +41,12 @@
TFTrainingArguments,
set_seed,
)
-from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, check_min_version
+from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, check_min_version, send_example_telemetry
from utils_qa import postprocess_qa_predictions
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
logger = logging.getLogger(__name__)
@@ -78,8 +78,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -115,37 +117,46 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=384,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
- "be faster on GPU but will be slower on TPU)."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
+ " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
version_2_with_negative: bool = field(
@@ -154,9 +165,11 @@ class DataTrainingArguments:
null_score_diff_threshold: float = field(
default=0.0,
metadata={
- "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
- "the score of the null answer minus this threshold, the null answer is selected for this example. "
- "Only useful when `version_2_with_negative=True`."
+ "help": (
+ "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ )
},
)
doc_stride: int = field(
@@ -170,8 +183,10 @@ class DataTrainingArguments:
max_answer_length: int = field(
default=30,
metadata={
- "help": "The maximum length of an answer that can be generated. This is needed because the start "
- "and end predictions are not conditioned on one another."
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
},
)
@@ -227,6 +242,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_qa", model_args, data_args, framework="tensorflow")
+
output_dir = Path(training_args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# endregion
@@ -330,9 +349,9 @@ def main():
# region Tokenizer check: this script requires a fast tokenizer.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
raise ValueError(
- "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
- "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
- "requirement"
+ "This example script only works for models that have a fast tokenizer. Checkout the big table of models at"
+ " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet"
+ " this requirement"
)
# endregion
diff --git a/examples/tensorflow/summarization/run_summarization.py b/examples/tensorflow/summarization/run_summarization.py
index 6c4f1e5a9ed9ef..e67dc9b2cc607f 100644
--- a/examples/tensorflow/summarization/run_summarization.py
+++ b/examples/tensorflow/summarization/run_summarization.py
@@ -44,13 +44,13 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version, is_offline_mode
+from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -99,8 +99,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -131,14 +133,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -151,60 +154,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -329,6 +348,10 @@ def main():
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_summarization", model_args, data_args, framework="tensorflow")
# endregion
# region Logging
diff --git a/examples/tensorflow/text-classification/run_glue.py b/examples/tensorflow/text-classification/run_glue.py
index c36476120eab30..9268d755e03f8a 100644
--- a/examples/tensorflow/text-classification/run_glue.py
+++ b/examples/tensorflow/text-classification/run_glue.py
@@ -39,7 +39,7 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
# region Helper functions
@@ -61,7 +61,7 @@ def on_epoch_end(self, epoch, logs=None):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
task_to_keys = {
"cola": ("sentence", None),
@@ -99,8 +99,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -109,29 +111,37 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
@@ -171,8 +181,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -194,6 +206,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_glue", model_args, data_args, framework="tensorflow")
+
if not (training_args.do_train or training_args.do_eval or training_args.do_predict):
exit("Must specify at least one of --do_train, --do_eval or --do_predict!")
# endregion
diff --git a/examples/tensorflow/text-classification/run_text_classification.py b/examples/tensorflow/text-classification/run_text_classification.py
index 3f3d64b6236d7b..210a30344dbc0e 100644
--- a/examples/tensorflow/text-classification/run_text_classification.py
+++ b/examples/tensorflow/text-classification/run_text_classification.py
@@ -37,7 +37,7 @@
TFTrainingArguments,
set_seed,
)
-from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME
+from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, send_example_telemetry
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" # Reduce the amount of console output from TF
@@ -85,8 +85,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -95,30 +97,38 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
- "Data will always be padded when using TPUs."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "Data will always be padded when using TPUs."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of test examples to this "
+ "value if set."
+ )
},
)
@@ -162,8 +172,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -184,6 +196,11 @@ def main():
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_text_classification", model_args, data_args, framework="tensorflow")
+
output_dir = Path(training_args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# endregion
@@ -330,8 +347,8 @@ def main():
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
- f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
- "\nIgnoring the model labels as a result.",
+ f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels:"
+ f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
)
label_to_id = {v: i for i, v in enumerate(label_list)}
elif not is_regression:
diff --git a/examples/tensorflow/token-classification/run_ner.py b/examples/tensorflow/token-classification/run_ner.py
index e580ed94b061cd..7eecf240cacd7a 100644
--- a/examples/tensorflow/token-classification/run_ner.py
+++ b/examples/tensorflow/token-classification/run_ner.py
@@ -41,6 +41,7 @@
create_optimizer,
set_seed,
)
+from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
@@ -80,8 +81,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -127,37 +130,47 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
label_all_tokens: bool = field(
default=False,
metadata={
- "help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
- "one (in which case the other tokens will have a padding index)."
+ "help": (
+ "Whether to put the label for one word on all tokens of generated by that word or just on the "
+ "one (in which case the other tokens will have a padding index)."
+ )
},
)
return_entity_level_metrics: bool = field(
@@ -240,6 +253,10 @@ def main():
# region Argument Parsing
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_ner", model_args, data_args, framework="tensorflow")
# endregion
# region Setup logging
diff --git a/examples/tensorflow/translation/run_translation.py b/examples/tensorflow/translation/run_translation.py
index f81148a4af0b1d..abce256ac9a76d 100644
--- a/examples/tensorflow/translation/run_translation.py
+++ b/examples/tensorflow/translation/run_translation.py
@@ -47,13 +47,13 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.19.0.dev0")
+check_min_version("4.20.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -93,8 +93,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
@@ -119,14 +121,15 @@ class DataTrainingArguments:
validation_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
- "(a jsonlines or csv file)."
+ "help": (
+ "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
+ )
},
)
test_file: Optional[str] = field(
default=None,
metadata={
- "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
+ "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
@@ -139,60 +142,76 @@ class DataTrainingArguments:
max_source_length: Optional[int] = field(
default=1024,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total sequence length for target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
- "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
- "during ``evaluate`` and ``predict``."
+ "help": (
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+ "during ``evaluate`` and ``predict``."
+ )
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
- "help": "Whether to pad all samples to model maximum sentence length. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
- "efficient on GPU but very bad for TPU."
+ "help": (
+ "Whether to pad all samples to model maximum sentence length. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+ "efficient on GPU but very bad for TPU."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ )
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
- "which is used during ``evaluate`` and ``predict``."
+ "help": (
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``."
+ )
},
)
ignore_pad_token_for_loss: bool = field(
@@ -299,6 +318,10 @@ def main():
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_translation", model_args, data_args, framework="tensorflow")
# endregion
# region Logging
diff --git a/model_cards/README.md b/model_cards/README.md
index 4bf6ac6186f33a..b2ee3e25a5d34d 100644
--- a/model_cards/README.md
+++ b/model_cards/README.md
@@ -15,11 +15,7 @@ You can either:
**What if you want to create or update a model card for a model you don't have write access to?**
-In that case, given that we don't have a Pull request system yet on huggingface.co (š¤Æ),
-you can open an issue here, post the card's content, and tag the model author(s) and/or the Hugging Face team.
-
-We might implement a more seamless process at some point, so your early feedback is precious!
-Please let us know of any suggestion.
+In that case, you can open a [Hub pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions)! Check out the [announcement](https://huggingface.co/blog/community-update) of this feature for more details š¤.
### What happened to the model cards here?
diff --git a/notebooks/README.md b/notebooks/README.md
index 073d2987027a1c..62f797a10a7ebf 100644
--- a/notebooks/README.md
+++ b/notebooks/README.md
@@ -88,4 +88,4 @@ You can open any page of the documentation as a notebook in colab (there is a bu
## Community notebooks:
-More notebooks developed by the community are available [here](community#community-notebooks).
+More notebooks developed by the community are available [here](https:hf.co/docs/transformers/community#community-notebooks).
diff --git a/setup.py b/setup.py
index e62d7d4d0197c9..abd080edad310d 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@
1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
documentation.
-
+
If releasing on a special branch, copy the updated README.md on the main branch for your the commit you will make
for the post-release and run `make fix-copies` on the main branch as well.
@@ -27,12 +27,13 @@
3. Unpin specific versions from setup.py that use a git install.
-4. Commit these changes with the message: "Release: " and push.
+4. Checkout the release branch (v-release, for example v4.19-release), and commit these changes with the
+ message: "Release: " and push.
5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for pypi' "
- Push the tag to git: git push --tags origin main
+ Push the tag to git: git push --tags origin v-release
7. Build both the sources and the wheel. Do not change anything in setup.py between
creating the wheel and the source distribution (obviously).
@@ -62,7 +63,7 @@
10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
-11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
+11. Run `make post-release` then run `make fix-copies`. If you were on a branch for the release,
you need to go back to main before executing this.
"""
@@ -96,12 +97,14 @@
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow",
- "black~=22.0",
+ "accelerate>=0.9.0",
+ "black~=22.0,>=22.3",
"codecarbon==1.2.0",
"cookiecutter==1.7.3",
"dataclasses",
"datasets",
- "deepspeed>=0.6.0",
+ "deepspeed>=0.6.5",
+ "dill<0.3.5",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
@@ -111,7 +114,7 @@
"ftfy",
"fugashi>=1.0",
"GitPython<3.1.19",
- "hf-doc-builder>=0.2.0",
+ "hf-doc-builder>=0.3.0",
"huggingface-hub>=0.1.0,<1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
@@ -129,17 +132,18 @@
"packaging>=20.0",
"parameterized",
"phonemizer",
- "protobuf",
+ "protobuf<=3.20.1",
"psutil",
"pyyaml>=5.1",
"pydantic",
"pytest",
"pytest-timeout",
"pytest-xdist",
- "python>=3.6.0",
+ "python>=3.7.0",
"ray[tune]",
"regex!=2019.12.17",
"requests",
+ "rjieba",
"rouge-score",
"sacrebleu>=1.4.12,<2.0.0",
"sacremoses",
@@ -281,6 +285,7 @@ def run(self):
"parameterized",
"psutil",
"datasets",
+ "dill",
"pytest-timeout",
"black",
"sacrebleu",
@@ -288,6 +293,9 @@ def run(self):
"nltk",
"GitPython",
"hf-doc-builder",
+ "protobuf", # Can be removed once we can unpin protobuf
+ "sacremoses",
+ "rjieba"
)
+ extras["retrieval"]
+ extras["modelcreation"]
@@ -365,7 +373,6 @@ def run(self):
"protobuf",
"regex",
"requests",
- "sacremoses",
"sentencepiece",
"torch",
"tokenizers",
@@ -374,7 +381,6 @@ def run(self):
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
install_requires = [
- deps["dataclasses"] + ";python_version<'3.7'", # dataclasses for Python versions that don't have it
deps["importlib_metadata"] + ";python_version<'3.8'", # importlib_metadata for Python versions that don't have it
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
deps["huggingface-hub"],
@@ -383,20 +389,19 @@ def run(self):
deps["pyyaml"], # used for the model cards metadata
deps["regex"], # for OpenAI GPT
deps["requests"], # for downloading models over HTTPS
- deps["sacremoses"], # for XLM
deps["tokenizers"],
deps["tqdm"], # progress bars in model download and training scripts
]
setup(
name="transformers",
- version="4.19.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
- author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
- author_email="thomas@huggingface.co",
- description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
+ version="4.20.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
+ author_email="transformers@huggingface.co",
+ description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
- keywords="NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU",
+ keywords="NLP vision speech deep learning transformer pytorch tensorflow BERT GPT-2 Wav2Vec2 ViT",
license="Apache",
url="https://github.com/huggingface/transformers",
package_dir={"": "src"},
@@ -405,7 +410,7 @@ def run(self):
zip_safe=False,
extras_require=extras,
entry_points={"console_scripts": ["transformers-cli=transformers.commands.transformers_cli:main"]},
- python_requires=">=3.6.0",
+ python_requires=">=3.7.0",
install_requires=install_requires,
classifiers=[
"Development Status :: 5 - Production/Stable",
@@ -415,7 +420,6 @@ def run(self):
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 5695ff57c53b07..1daae6e92034bc 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -22,13 +22,14 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.19.0.dev0"
+__version__ = "4.20.0.dev0"
from typing import TYPE_CHECKING
# Check the dependencies satisfy the minimal versions required.
from . import dependency_versions_check
from .utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_scatter_available,
@@ -155,6 +156,7 @@
"BlenderbotSmallConfig",
"BlenderbotSmallTokenizer",
],
+ "models.bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig"],
"models.bort": [],
"models.byt5": ["ByT5Tokenizer"],
"models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
@@ -168,8 +170,9 @@
],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models.convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"],
- "models.cpm": ["CpmTokenizer"],
+ "models.cpm": [],
"models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"],
+ "models.cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"],
"models.data2vec": [
"DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -197,12 +200,21 @@
"models.electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraTokenizer"],
"models.encoder_decoder": ["EncoderDecoderConfig"],
"models.flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertTokenizer"],
- "models.fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig", "FNetTokenizer"],
+ "models.flava": [
+ "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "FlavaConfig",
+ "FlavaImageCodebookConfig",
+ "FlavaImageConfig",
+ "FlavaMultimodalConfig",
+ "FlavaTextConfig",
+ ],
+ "models.fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig"],
"models.fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig", "FSMTTokenizer"],
"models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
"models.glpn": ["GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP", "GLPNConfig"],
"models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
"models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
+ "models.gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"],
"models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
"models.herbert": ["HerbertTokenizer"],
"models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
@@ -216,9 +228,18 @@
"LayoutLMv2Processor",
"LayoutLMv2Tokenizer",
],
+ "models.layoutlmv3": [
+ "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "LayoutLMv3Config",
+ "LayoutLMv3FeatureExtractor",
+ "LayoutLMv3Processor",
+ "LayoutLMv3Tokenizer",
+ ],
"models.layoutxlm": ["LayoutXLMProcessor"],
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
+ "models.levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"],
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
+ "models.longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config"],
"models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"],
"models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
@@ -226,6 +247,7 @@
"models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
+ "models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTProcessor"],
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
"models.megatron_gpt2": [],
"models.mluke": [],
@@ -238,6 +260,7 @@
"NystromformerConfig",
],
"models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"],
+ "models.opt": ["OPTConfig"],
"models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"],
"models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"],
"models.phobert": ["PhobertTokenizer"],
@@ -274,6 +297,10 @@
"models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
"models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"],
"models.tapex": ["TapexTokenizer"],
+ "models.trajectory_transformer": [
+ "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "TrajectoryTransformerConfig",
+ ],
"models.transfo_xl": [
"TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TransfoXLConfig",
@@ -308,6 +335,10 @@
"Wav2Vec2Processor",
"Wav2Vec2Tokenizer",
],
+ "models.wav2vec2_conformer": [
+ "WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "Wav2Vec2ConformerConfig",
+ ],
"models.wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"],
"models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"],
"models.wavlm": [
@@ -320,6 +351,7 @@
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
"models.xlm_roberta_xl": ["XLM_ROBERTA_XL_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaXLConfig"],
"models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
+ "models.yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"],
"models.yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
"onnx": [],
"pipelines": [
@@ -346,6 +378,7 @@
"TextGenerationPipeline",
"TokenClassificationPipeline",
"TranslationPipeline",
+ "VisualQuestionAnsweringPipeline",
"ZeroShotClassificationPipeline",
"ZeroShotImageClassificationPipeline",
"pipeline",
@@ -370,7 +403,7 @@
"TrainerControl",
"TrainerState",
],
- "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
+ "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "enable_full_determinism", "set_seed"],
"training_args": ["TrainingArguments"],
"training_args_seq2seq": ["Seq2SeqTrainingArguments"],
"training_args_tf": ["TFTrainingArguments"],
@@ -411,14 +444,25 @@
}
# sentencepiece-backed objects
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_sentencepiece_objects
+
+ _import_structure["utils.dummy_sentencepiece_objects"] = [
+ name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["models.albert"].append("AlbertTokenizer")
_import_structure["models.barthez"].append("BarthezTokenizer")
_import_structure["models.bartpho"].append("BartphoTokenizer")
_import_structure["models.bert_generation"].append("BertGenerationTokenizer")
_import_structure["models.big_bird"].append("BigBirdTokenizer")
_import_structure["models.camembert"].append("CamembertTokenizer")
+ _import_structure["models.cpm"].append("CpmTokenizer")
_import_structure["models.deberta_v2"].append("DebertaV2Tokenizer")
+ _import_structure["models.fnet"].append("FNetTokenizer")
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizer")
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
@@ -436,16 +480,19 @@
_import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
_import_structure["models.xlnet"].append("XLNetTokenizer")
-else:
- from .utils import dummy_sentencepiece_objects
-
- _import_structure["utils.dummy_sentencepiece_objects"] = [
- name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_")
- ]
# tokenizers-backed objects
-if is_tokenizers_available():
- # Fast tokenizers
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tokenizers_objects
+
+ _import_structure["utils.dummy_tokenizers_objects"] = [
+ name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
+ ]
+else:
+ # Fast tokenizers structure
_import_structure["models.albert"].append("AlbertTokenizerFast")
_import_structure["models.bart"].append("BartTokenizerFast")
_import_structure["models.barthez"].append("BarthezTokenizerFast")
@@ -453,9 +500,11 @@
_import_structure["models.big_bird"].append("BigBirdTokenizerFast")
_import_structure["models.blenderbot"].append("BlenderbotTokenizerFast")
_import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast")
+ _import_structure["models.bloom"].append("BloomTokenizerFast")
_import_structure["models.camembert"].append("CamembertTokenizerFast")
_import_structure["models.clip"].append("CLIPTokenizerFast")
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
+ _import_structure["models.cpm"].append("CpmTokenizerFast")
_import_structure["models.deberta"].append("DebertaTokenizerFast")
_import_structure["models.deberta_v2"].append("DebertaV2TokenizerFast")
_import_structure["models.distilbert"].append("DistilBertTokenizerFast")
@@ -466,9 +515,11 @@
_import_structure["models.fnet"].append("FNetTokenizerFast")
_import_structure["models.funnel"].append("FunnelTokenizerFast")
_import_structure["models.gpt2"].append("GPT2TokenizerFast")
+ _import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast")
_import_structure["models.herbert"].append("HerbertTokenizerFast")
_import_structure["models.layoutlm"].append("LayoutLMTokenizerFast")
_import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast")
+ _import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
_import_structure["models.led"].append("LEDTokenizerFast")
_import_structure["models.longformer"].append("LongformerTokenizerFast")
@@ -494,43 +545,56 @@
_import_structure["models.xlnet"].append("XLNetTokenizerFast")
_import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"]
-else:
- from .utils import dummy_tokenizers_objects
-
- _import_structure["utils.dummy_tokenizers_objects"] = [
- name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
- ]
-if is_sentencepiece_available() and is_tokenizers_available():
- _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
-else:
+try:
+ if not (is_sentencepiece_available() and is_tokenizers_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
from .utils import dummy_sentencepiece_and_tokenizers_objects
_import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [
name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_")
]
+else:
+ _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
# Speech-specific objects
-if is_speech_available():
- _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
-else:
+try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
from .utils import dummy_speech_objects
_import_structure["utils.dummy_speech_objects"] = [
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
-
-if is_sentencepiece_available() and is_speech_available():
- _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
else:
+ _import_structure["models.mctct"].append("MCTCTFeatureExtractor")
+ _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
+
+try:
+ if not (is_sentencepiece_available() and is_speech_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
from .utils import dummy_sentencepiece_and_speech_objects
_import_structure["utils.dummy_sentencepiece_and_speech_objects"] = [
name for name in dir(dummy_sentencepiece_and_speech_objects) if not name.startswith("_")
]
+else:
+ _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
# Vision-specific objects
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_vision_objects
+
+ _import_structure["utils.dummy_vision_objects"] = [
+ name for name in dir(dummy_vision_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.beit"].append("BeitFeatureExtractor")
_import_structure["models.clip"].append("CLIPFeatureExtractor")
@@ -539,11 +603,12 @@
_import_structure["models.deit"].append("DeiTFeatureExtractor")
_import_structure["models.detr"].append("DetrFeatureExtractor")
_import_structure["models.dpt"].append("DPTFeatureExtractor")
+ _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor"])
_import_structure["models.glpn"].append("GLPNFeatureExtractor")
_import_structure["models.imagegpt"].append("ImageGPTFeatureExtractor")
_import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor")
- _import_structure["models.layoutlmv2"].append("LayoutLMv2Processor")
- _import_structure["models.layoutxlm"].append("LayoutXLMProcessor")
+ _import_structure["models.layoutlmv3"].append("LayoutLMv3FeatureExtractor")
+ _import_structure["models.levit"].append("LevitFeatureExtractor")
_import_structure["models.maskformer"].append("MaskFormerFeatureExtractor")
_import_structure["models.perceiver"].append("PerceiverFeatureExtractor")
_import_structure["models.poolformer"].append("PoolFormerFeatureExtractor")
@@ -551,15 +616,19 @@
_import_structure["models.vilt"].append("ViltFeatureExtractor")
_import_structure["models.vilt"].append("ViltProcessor")
_import_structure["models.vit"].append("ViTFeatureExtractor")
-else:
- from .utils import dummy_vision_objects
-
- _import_structure["utils.dummy_vision_objects"] = [
- name for name in dir(dummy_vision_objects) if not name.startswith("_")
- ]
+ _import_structure["models.yolos"].append("YolosFeatureExtractor")
# Timm-backed objects
-if is_timm_available() and is_vision_available():
+try:
+ if not (is_timm_available() and is_vision_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_timm_objects
+
+ _import_structure["utils.dummy_timm_objects"] = [
+ name for name in dir(dummy_timm_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["models.detr"].extend(
[
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -569,14 +638,17 @@
"DetrPreTrainedModel",
]
)
-else:
- from .utils import dummy_timm_objects
- _import_structure["utils.dummy_timm_objects"] = [
- name for name in dir(dummy_timm_objects) if not name.startswith("_")
- ]
+try:
+ if not is_scatter_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_scatter_objects
-if is_scatter_available():
+ _import_structure["utils.dummy_scatter_objects"] = [
+ name for name in dir(dummy_scatter_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["models.tapas"].extend(
[
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -588,16 +660,17 @@
"load_tf_weights_in_tapas",
]
)
-else:
- from .utils import dummy_scatter_objects
-
- _import_structure["utils.dummy_scatter_objects"] = [
- name for name in dir(dummy_scatter_objects) if not name.startswith("_")
- ]
# PyTorch-backed objects
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_pt_objects
+
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
+else:
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
@@ -636,6 +709,7 @@
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
+ "TypicalLogitsWarper",
]
_import_structure["generation_stopping_criteria"] = [
"MaxLengthCriteria",
@@ -686,6 +760,7 @@
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
+ "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
@@ -711,6 +786,7 @@
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForVision2Seq",
+ "AutoModelForVisualQuestionAnswering",
"AutoModelWithLMHead",
]
)
@@ -788,6 +864,16 @@
"BigBirdPegasusPreTrainedModel",
]
)
+ _import_structure["models.bloom"].extend(
+ [
+ "BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "BloomForCausalLM",
+ "BloomModel",
+ "BloomPreTrainedModel",
+ "BloomForSequenceClassification",
+ "BloomForTokenClassification",
+ ]
+ )
_import_structure["models.blenderbot"].extend(
[
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -871,6 +957,14 @@
"CTRLPreTrainedModel",
]
)
+ _import_structure["models.cvt"].extend(
+ [
+ "CVT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "CvtForImageClassification",
+ "CvtModel",
+ "CvtPreTrainedModel",
+ ]
+ )
_import_structure["models.data2vec"].extend(
[
"DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -911,6 +1005,7 @@
[
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaV2ForMaskedLM",
+ "DebertaV2ForMultipleChoice",
"DebertaV2ForQuestionAnswering",
"DebertaV2ForSequenceClassification",
"DebertaV2ForTokenClassification",
@@ -1000,6 +1095,18 @@
"FlaubertWithLMHeadModel",
]
)
+ _import_structure["models.flava"].extend(
+ [
+ "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "FlavaForPreTraining",
+ "FlavaImageCodebook",
+ "FlavaImageModel",
+ "FlavaModel",
+ "FlavaMultimodalModel",
+ "FlavaPreTrainedModel",
+ "FlavaTextModel",
+ ]
+ )
_import_structure["models.fnet"].extend(
[
"FNET_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1061,6 +1168,15 @@
"load_tf_weights_in_gpt_neo",
]
)
+ _import_structure["models.gpt_neox"].extend(
+ [
+ "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "GPTNeoXForCausalLM",
+ "GPTNeoXLayer",
+ "GPTNeoXModel",
+ "GPTNeoXPreTrainedModel",
+ ]
+ )
_import_structure["models.gptj"].extend(
[
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1122,6 +1238,16 @@
"LayoutLMv2PreTrainedModel",
]
)
+ _import_structure["models.layoutlmv3"].extend(
+ [
+ "LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LayoutLMv3ForQuestionAnswering",
+ "LayoutLMv3ForSequenceClassification",
+ "LayoutLMv3ForTokenClassification",
+ "LayoutLMv3Model",
+ "LayoutLMv3PreTrainedModel",
+ ]
+ )
_import_structure["models.led"].extend(
[
"LED_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1132,6 +1258,15 @@
"LEDPreTrainedModel",
]
)
+ _import_structure["models.levit"].extend(
+ [
+ "LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LevitForImageClassification",
+ "LevitForImageClassificationWithTeacher",
+ "LevitModel",
+ "LevitPreTrainedModel",
+ ]
+ )
_import_structure["models.longformer"].extend(
[
"LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1145,6 +1280,15 @@
"LongformerSelfAttention",
]
)
+ _import_structure["models.longt5"].extend(
+ [
+ "LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LongT5EncoderModel",
+ "LongT5ForConditionalGeneration",
+ "LongT5Model",
+ "LongT5PreTrainedModel",
+ ]
+ )
_import_structure["models.luke"].extend(
[
"LUKE_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1194,6 +1338,14 @@
"MBartPreTrainedModel",
]
)
+ _import_structure["models.mctct"].extend(
+ [
+ "MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MCTCTForCTC",
+ "MCTCTModel",
+ "MCTCTPreTrainedModel",
+ ]
+ )
_import_structure["models.megatron_bert"].extend(
[
"MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1264,6 +1416,14 @@
"load_tf_weights_in_openai_gpt",
]
)
+ _import_structure["models.opt"].extend(
+ [
+ "OPT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "OPTForCausalLM",
+ "OPTModel",
+ "OPTPreTrainedModel",
+ ]
+ )
_import_structure["models.pegasus"].extend(
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
)
@@ -1460,6 +1620,7 @@
_import_structure["models.splinter"].extend(
[
"SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "SplinterForPreTraining",
"SplinterForQuestionAnswering",
"SplinterLayer",
"SplinterModel",
@@ -1498,6 +1659,13 @@
"load_tf_weights_in_t5",
]
)
+ _import_structure["models.trajectory_transformer"].extend(
+ [
+ "TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TrajectoryTransformerModel",
+ "TrajectoryTransformerPreTrainedModel",
+ ]
+ )
_import_structure["models.transfo_xl"].extend(
[
"TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1600,6 +1768,18 @@
"Wav2Vec2PreTrainedModel",
]
)
+ _import_structure["models.wav2vec2_conformer"].extend(
+ [
+ "WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "Wav2Vec2ConformerForAudioFrameClassification",
+ "Wav2Vec2ConformerForCTC",
+ "Wav2Vec2ConformerForPreTraining",
+ "Wav2Vec2ConformerForSequenceClassification",
+ "Wav2Vec2ConformerForXVector",
+ "Wav2Vec2ConformerModel",
+ "Wav2Vec2ConformerPreTrainedModel",
+ ]
+ )
_import_structure["models.wavlm"].extend(
[
"WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1681,6 +1861,14 @@
"load_tf_weights_in_xlnet",
]
)
+ _import_structure["models.yolos"].extend(
+ [
+ "YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "YolosForObjectDetection",
+ "YolosModel",
+ "YolosPreTrainedModel",
+ ]
+ )
_import_structure["models.yoso"].extend(
[
"YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1710,13 +1898,16 @@
_import_structure["trainer"] = ["Trainer"]
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
_import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]
-else:
- from .utils import dummy_pt_objects
-
- _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
# TensorFlow-backed objects
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tf_objects
+
+ _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
+else:
_import_structure["activations_tf"] = []
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
@@ -1762,6 +1953,7 @@
[
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"TF_MODEL_FOR_MASKED_LM_MAPPING",
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@@ -1780,6 +1972,7 @@
"TFAutoModelForImageClassification",
"TFAutoModelForMaskedLM",
"TFAutoModelForMultipleChoice",
+ "TFAutoModelForNextSentencePrediction",
"TFAutoModelForPreTraining",
"TFAutoModelForQuestionAnswering",
"TFAutoModelForSeq2SeqLM",
@@ -1865,6 +2058,14 @@
"TFCTRLPreTrainedModel",
]
)
+ _import_structure["models.data2vec"].extend(
+ [
+ "TFData2VecVisionForImageClassification",
+ "TFData2VecVisionForSemanticSegmentation",
+ "TFData2VecVisionModel",
+ "TFData2VecVisionPreTrainedModel",
+ ]
+ )
_import_structure["models.deberta"].extend(
[
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2060,6 +2261,13 @@
"TFOpenAIGPTPreTrainedModel",
]
)
+ _import_structure["models.opt"].extend(
+ [
+ "TFOPTForCausalLM",
+ "TFOPTModel",
+ "TFOPTPreTrainedModel",
+ ]
+ )
_import_structure["models.pegasus"].extend(
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
)
@@ -2121,6 +2329,15 @@
"TFSpeech2TextPreTrainedModel",
]
)
+ _import_structure["models.swin"].extend(
+ [
+ "TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFSwinForImageClassification",
+ "TFSwinForMaskedImageModeling",
+ "TFSwinModel",
+ "TFSwinPreTrainedModel",
+ ]
+ )
_import_structure["models.t5"].extend(
[
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2215,13 +2432,18 @@
_import_structure["tf_utils"] = []
_import_structure["trainer_tf"] = ["TFTrainer"]
-else:
- from .utils import dummy_tf_objects
-
- _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
# FLAX-backed objects
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_flax_objects
+
+ _import_structure["utils.dummy_flax_objects"] = [
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
+ ]
+else:
_import_structure["generation_flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
@@ -2290,7 +2512,6 @@
"FlaxBartPreTrainedModel",
]
)
-
_import_structure["models.beit"].extend(
[
"FlaxBeitForImageClassification",
@@ -2299,8 +2520,10 @@
"FlaxBeitPreTrainedModel",
]
)
+
_import_structure["models.bert"].extend(
[
+ "FlaxBertForCausalLM",
"FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction",
@@ -2314,6 +2537,7 @@
)
_import_structure["models.big_bird"].extend(
[
+ "FlaxBigBirdForCausalLM",
"FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining",
@@ -2357,6 +2581,7 @@
)
_import_structure["models.electra"].extend(
[
+ "FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
@@ -2373,6 +2598,9 @@
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
+ _import_structure["models.longt5"].extend(
+ ["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"]
+ )
_import_structure["models.marian"].extend(
[
"FlaxMarianModel",
@@ -2390,6 +2618,13 @@
]
)
_import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
+ _import_structure["models.opt"].extend(
+ [
+ "FlaxOPTForCausalLM",
+ "FlaxOPTModel",
+ "FlaxOPTPreTrainedModel",
+ ]
+ )
_import_structure["models.pegasus"].extend(
[
"FlaxPegasusForConditionalGeneration",
@@ -2399,6 +2634,7 @@
)
_import_structure["models.roberta"].extend(
[
+ "FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
@@ -2444,12 +2680,6 @@
"FlaxXLMRobertaModel",
]
)
-else:
- from .utils import dummy_flax_objects
-
- _import_structure["utils.dummy_flax_objects"] = [
- name for name in dir(dummy_flax_objects) if not name.startswith("_")
- ]
# Direct imports for type-checking
@@ -2553,6 +2783,7 @@
BlenderbotSmallConfig,
BlenderbotSmallTokenizer,
)
+ from .models.bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
from .models.byt5 import ByT5Tokenizer
from .models.camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .models.canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig, CanineTokenizer
@@ -2565,8 +2796,8 @@
)
from .models.convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertTokenizer
from .models.convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig
- from .models.cpm import CpmTokenizer
from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer
+ from .models.cvt import CVT_PRETRAINED_CONFIG_ARCHIVE_MAP, CvtConfig
from .models.data2vec import (
DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -2595,12 +2826,21 @@
from .models.electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraTokenizer
from .models.encoder_decoder import EncoderDecoderConfig
from .models.flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertTokenizer
- from .models.fnet import FNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FNetConfig, FNetTokenizer
+ from .models.flava import (
+ FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ FlavaConfig,
+ FlavaImageCodebookConfig,
+ FlavaImageConfig,
+ FlavaMultimodalConfig,
+ FlavaTextConfig,
+ )
+ from .models.fnet import FNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FNetConfig
from .models.fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig, FSMTTokenizer
from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
from .models.glpn import GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP, GLPNConfig
from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
+ from .models.gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig
from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
from .models.herbert import HerbertTokenizer
from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
@@ -2614,15 +2854,25 @@
LayoutLMv2Processor,
LayoutLMv2Tokenizer,
)
+ from .models.layoutlmv3 import (
+ LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ LayoutLMv3Config,
+ LayoutLMv3FeatureExtractor,
+ LayoutLMv3Processor,
+ LayoutLMv3Tokenizer,
+ )
from .models.layoutxlm import LayoutXLMProcessor
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
+ from .models.levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
+ from .models.longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config
from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from .models.marian import MarianConfig
from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
from .models.mbart import MBartConfig
+ from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTProcessor
from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
from .models.mmbt import MMBTConfig
from .models.mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig, MobileBertTokenizer
@@ -2630,6 +2880,7 @@
from .models.mt5 import MT5Config
from .models.nystromformer import NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, NystromformerConfig
from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer
+ from .models.opt import OPTConfig
from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer
from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer
from .models.phobert import PhobertTokenizer
@@ -2663,6 +2914,10 @@
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
from .models.tapex import TapexTokenizer
+ from .models.trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ TrajectoryTransformerConfig,
+ )
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
TransfoXLConfig,
@@ -2687,6 +2942,7 @@
Wav2Vec2Processor,
Wav2Vec2Tokenizer,
)
+ from .models.wav2vec2_conformer import WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2ConformerConfig
from .models.wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizer
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from .models.wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig
@@ -2696,6 +2952,7 @@
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from .models.xlm_roberta_xl import XLM_ROBERTA_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaXLConfig
from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
+ from .models.yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig
from .models.yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig
# Pipelines
@@ -2723,6 +2980,7 @@
TextGenerationPipeline,
TokenClassificationPipeline,
TranslationPipeline,
+ VisualQuestionAnsweringPipeline,
ZeroShotClassificationPipeline,
ZeroShotImageClassificationPipeline,
pipeline,
@@ -2750,7 +3008,7 @@
TrainerControl,
TrainerState,
)
- from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
+ from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, enable_full_determinism, set_seed
from .training_args import TrainingArguments
from .training_args_seq2seq import Seq2SeqTrainingArguments
from .training_args_tf import TFTrainingArguments
@@ -2791,14 +3049,21 @@
logging,
)
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_sentencepiece_objects import *
+ else:
from .models.albert import AlbertTokenizer
from .models.barthez import BarthezTokenizer
from .models.bartpho import BartphoTokenizer
from .models.bert_generation import BertGenerationTokenizer
from .models.big_bird import BigBirdTokenizer
from .models.camembert import CamembertTokenizer
+ from .models.cpm import CpmTokenizer
from .models.deberta_v2 import DebertaV2Tokenizer
+ from .models.fnet import FNetTokenizer
from .models.layoutxlm import LayoutXLMTokenizer
from .models.m2m_100 import M2M100Tokenizer
from .models.marian import MarianTokenizer
@@ -2815,10 +3080,14 @@
from .models.xlm_prophetnet import XLMProphetNetTokenizer
from .models.xlm_roberta import XLMRobertaTokenizer
from .models.xlnet import XLNetTokenizer
- else:
- from .utils.dummy_sentencepiece_objects import *
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_tokenizers_objects import *
+ else:
+ # Fast tokenizers imports
from .models.albert import AlbertTokenizerFast
from .models.bart import BartTokenizerFast
from .models.barthez import BarthezTokenizerFast
@@ -2826,9 +3095,11 @@
from .models.big_bird import BigBirdTokenizerFast
from .models.blenderbot import BlenderbotTokenizerFast
from .models.blenderbot_small import BlenderbotSmallTokenizerFast
+ from .models.bloom import BloomTokenizerFast
from .models.camembert import CamembertTokenizerFast
from .models.clip import CLIPTokenizerFast
from .models.convbert import ConvBertTokenizerFast
+ from .models.cpm import CpmTokenizerFast
from .models.deberta import DebertaTokenizerFast
from .models.deberta_v2 import DebertaV2TokenizerFast
from .models.distilbert import DistilBertTokenizerFast
@@ -2837,9 +3108,11 @@
from .models.fnet import FNetTokenizerFast
from .models.funnel import FunnelTokenizerFast
from .models.gpt2 import GPT2TokenizerFast
+ from .models.gpt_neox import GPTNeoXTokenizerFast
from .models.herbert import HerbertTokenizerFast
from .models.layoutlm import LayoutLMTokenizerFast
from .models.layoutlmv2 import LayoutLMv2TokenizerFast
+ from .models.layoutlmv3 import LayoutLMv3TokenizerFast
from .models.layoutxlm import LayoutXLMTokenizerFast
from .models.led import LEDTokenizerFast
from .models.longformer import LongformerTokenizerFast
@@ -2865,25 +3138,37 @@
from .models.xlnet import XLNetTokenizerFast
from .tokenization_utils_fast import PreTrainedTokenizerFast
+ try:
+ if not (is_sentencepiece_available() and is_tokenizers_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummies_sentencepiece_and_tokenizers_objects import *
else:
- from .utils.dummy_tokenizers_objects import *
-
- if is_sentencepiece_available() and is_tokenizers_available():
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
- else:
- from .utils.dummies_sentencepiece_and_tokenizers_objects import *
- if is_speech_available():
- from .models.speech_to_text import Speech2TextFeatureExtractor
- else:
+ try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
from .utils.dummy_speech_objects import *
-
- if is_speech_available() and is_sentencepiece_available():
- from .models.speech_to_text import Speech2TextProcessor
else:
+ from .models.mctct import MCTCTFeatureExtractor
+ from .models.speech_to_text import Speech2TextFeatureExtractor
+
+ try:
+ if not (is_speech_available() and is_sentencepiece_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
from .utils.dummy_sentencepiece_and_speech_objects import *
+ else:
+ from .models.speech_to_text import Speech2TextProcessor
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_vision_objects import *
+ else:
from .image_utils import ImageFeatureExtractionMixin
from .models.beit import BeitFeatureExtractor
from .models.clip import CLIPFeatureExtractor, CLIPProcessor
@@ -2891,21 +3176,27 @@
from .models.deit import DeiTFeatureExtractor
from .models.detr import DetrFeatureExtractor
from .models.dpt import DPTFeatureExtractor
+ from .models.flava import FlavaFeatureExtractor, FlavaProcessor
from .models.glpn import GLPNFeatureExtractor
from .models.imagegpt import ImageGPTFeatureExtractor
- from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor
- from .models.layoutxlm import LayoutXLMProcessor
+ from .models.layoutlmv2 import LayoutLMv2FeatureExtractor
+ from .models.layoutlmv3 import LayoutLMv3FeatureExtractor
+ from .models.levit import LevitFeatureExtractor
from .models.maskformer import MaskFormerFeatureExtractor
from .models.perceiver import PerceiverFeatureExtractor
from .models.poolformer import PoolFormerFeatureExtractor
from .models.segformer import SegformerFeatureExtractor
from .models.vilt import ViltFeatureExtractor, ViltProcessor
from .models.vit import ViTFeatureExtractor
- else:
- from .utils.dummy_vision_objects import *
+ from .models.yolos import YolosFeatureExtractor
# Modeling
- if is_timm_available() and is_vision_available():
+ try:
+ if not (is_timm_available() and is_vision_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_timm_objects import *
+ else:
from .models.detr import (
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DetrForObjectDetection,
@@ -2913,10 +3204,13 @@
DetrModel,
DetrPreTrainedModel,
)
- else:
- from .utils.dummy_timm_objects import *
- if is_scatter_available():
+ try:
+ if not is_scatter_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_scatter_objects import *
+ else:
from .models.tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
@@ -2926,10 +3220,13 @@
TapasPreTrainedModel,
load_tf_weights_in_tapas,
)
- else:
- from .utils.dummy_scatter_objects import *
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_pt_objects import *
+ else:
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
@@ -2967,6 +3264,7 @@
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
+ TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
MaxLengthCriteria,
@@ -2976,6 +3274,8 @@
)
from .generation_utils import top_k_top_p_filtering
from .modeling_utils import PreTrainedModel
+
+ # PyTorch model imports
from .models.albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
AlbertForMaskedLM,
@@ -3011,6 +3311,7 @@
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
+ MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
@@ -3036,6 +3337,7 @@
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
+ AutoModelForVisualQuestionAnswering,
AutoModelWithLMHead,
)
from .models.bart import (
@@ -3114,6 +3416,14 @@
BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
)
+ from .models.bloom import (
+ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
+ BloomForCausalLM,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ BloomModel,
+ BloomPreTrainedModel,
+ )
from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
CamembertForCausalLM,
@@ -3167,6 +3477,12 @@
CTRLModel,
CTRLPreTrainedModel,
)
+ from .models.cvt import (
+ CVT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ CvtForImageClassification,
+ CvtModel,
+ CvtPreTrainedModel,
+ )
from .models.data2vec import (
DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST,
DATA2VEC_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -3202,6 +3518,7 @@
from .models.deberta_v2 import (
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaV2ForMaskedLM,
+ DebertaV2ForMultipleChoice,
DebertaV2ForQuestionAnswering,
DebertaV2ForSequenceClassification,
DebertaV2ForTokenClassification,
@@ -3276,6 +3593,16 @@
FlaubertModel,
FlaubertWithLMHeadModel,
)
+ from .models.flava import (
+ FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
+ FlavaForPreTraining,
+ FlavaImageCodebook,
+ FlavaImageModel,
+ FlavaModel,
+ FlavaMultimodalModel,
+ FlavaPreTrainedModel,
+ FlavaTextModel,
+ )
from .models.fnet import (
FNET_PRETRAINED_MODEL_ARCHIVE_LIST,
FNetForMaskedLM,
@@ -3327,6 +3654,13 @@
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,
)
+ from .models.gpt_neox import (
+ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
+ GPTNeoXForCausalLM,
+ GPTNeoXLayer,
+ GPTNeoXModel,
+ GPTNeoXPreTrainedModel,
+ )
from .models.gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
@@ -3376,6 +3710,14 @@
LayoutLMv2Model,
LayoutLMv2PreTrainedModel,
)
+ from .models.layoutlmv3 import (
+ LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LayoutLMv3ForQuestionAnswering,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3Model,
+ LayoutLMv3PreTrainedModel,
+ )
from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
LEDForConditionalGeneration,
@@ -3384,6 +3726,13 @@
LEDModel,
LEDPreTrainedModel,
)
+ from .models.levit import (
+ LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LevitForImageClassification,
+ LevitForImageClassificationWithTeacher,
+ LevitModel,
+ LevitPreTrainedModel,
+ )
from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
LongformerForMaskedLM,
@@ -3395,6 +3744,13 @@
LongformerPreTrainedModel,
LongformerSelfAttention,
)
+ from .models.longt5 import (
+ LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LongT5EncoderModel,
+ LongT5ForConditionalGeneration,
+ LongT5Model,
+ LongT5PreTrainedModel,
+ )
from .models.luke import (
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST,
LukeForEntityClassification,
@@ -3434,6 +3790,7 @@
MBartModel,
MBartPreTrainedModel,
)
+ from .models.mctct import MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST, MCTCTForCTC, MCTCTModel, MCTCTPreTrainedModel
from .models.megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MegatronBertForCausalLM,
@@ -3494,6 +3851,7 @@
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
+ from .models.opt import OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, OPTModel, OPTPreTrainedModel
from .models.pegasus import (
PegasusForCausalLM,
PegasusForConditionalGeneration,
@@ -3656,6 +4014,7 @@
from .models.speech_to_text_2 import Speech2Text2ForCausalLM, Speech2Text2PreTrainedModel
from .models.splinter import (
SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ SplinterForPreTraining,
SplinterForQuestionAnswering,
SplinterLayer,
SplinterModel,
@@ -3687,6 +4046,11 @@
T5PreTrainedModel,
load_tf_weights_in_t5,
)
+ from .models.trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TrajectoryTransformerModel,
+ TrajectoryTransformerPreTrainedModel,
+ )
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding,
@@ -3769,6 +4133,16 @@
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
+ from .models.wav2vec2_conformer import (
+ WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForXVector,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2ConformerPreTrainedModel,
+ )
from .models.wavlm import (
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
WavLMForAudioFrameClassification,
@@ -3831,6 +4205,12 @@
XLNetPreTrainedModel,
load_tf_weights_in_xlnet,
)
+ from .models.yolos import (
+ YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST,
+ YolosForObjectDetection,
+ YolosModel,
+ YolosPreTrainedModel,
+ )
from .models.yoso import (
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
YosoForMaskedLM,
@@ -3861,12 +4241,16 @@
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
from .trainer_seq2seq import Seq2SeqTrainer
- else:
- from .utils.dummy_pt_objects import *
# TensorFlow
- if is_tf_available():
-
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ # Import the same objects as dummies to get them in the namespace.
+ # They will raise an import error if the user tries to instantiate / use them.
+ from .utils.dummy_tf_objects import *
+ else:
from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments
# Benchmarks
@@ -3897,6 +4281,8 @@
TFLayoutLMPreTrainedModel,
)
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, shape_list
+
+ # TensorFlow model imports
from .models.albert import (
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAlbertForMaskedLM,
@@ -3912,6 +4298,7 @@
from .models.auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -3930,6 +4317,7 @@
TFAutoModelForImageClassification,
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
+ TFAutoModelForNextSentencePrediction,
TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
@@ -4002,6 +4390,12 @@
TFCTRLModel,
TFCTRLPreTrainedModel,
)
+ from .models.data2vec import (
+ TFData2VecVisionForImageClassification,
+ TFData2VecVisionForSemanticSegmentation,
+ TFData2VecVisionModel,
+ TFData2VecVisionPreTrainedModel,
+ )
from .models.deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
@@ -4154,6 +4548,7 @@
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
+ from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.rembert import (
@@ -4198,6 +4593,13 @@
TFSpeech2TextModel,
TFSpeech2TextPreTrainedModel,
)
+ from .models.swin import (
+ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
+ TFSwinModel,
+ TFSwinPreTrainedModel,
+ )
from .models.t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
@@ -4269,13 +4671,14 @@
# Trainer
from .trainer_tf import TFTrainer
- else:
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
- from .utils.dummy_tf_objects import *
-
- if is_flax_available():
-
+ from .utils.dummy_flax_objects import *
+ else:
from .generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
@@ -4288,6 +4691,8 @@
FlaxTopPLogitsWarper,
)
from .modeling_flax_utils import FlaxPreTrainedModel
+
+ # Flax model imports
from .models.albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
@@ -4340,6 +4745,7 @@
FlaxBeitPreTrainedModel,
)
from .models.bert import (
+ FlaxBertForCausalLM,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
@@ -4351,6 +4757,7 @@
FlaxBertPreTrainedModel,
)
from .models.big_bird import (
+ FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
@@ -4388,6 +4795,7 @@
FlaxDistilBertPreTrainedModel,
)
from .models.electra import (
+ FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
@@ -4401,6 +4809,7 @@
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
+ from .models.longt5 import FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, FlaxLongT5PreTrainedModel
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
from .models.mbart import (
FlaxMBartForConditionalGeneration,
@@ -4410,8 +4819,10 @@
FlaxMBartPreTrainedModel,
)
from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+ from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.roberta import (
+ FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
@@ -4449,10 +4860,6 @@
FlaxXLMRobertaForTokenClassification,
FlaxXLMRobertaModel,
)
- else:
- # Import the same objects as dummies to get them in the namespace.
- # They will raise an import error if the user tries to instantiate / use them.
- from .utils.dummy_flax_objects import *
else:
import sys
diff --git a/src/transformers/benchmark/benchmark.py b/src/transformers/benchmark/benchmark.py
index 8569c6e324e3da..7f95e4b40b7cd3 100644
--- a/src/transformers/benchmark/benchmark.py
+++ b/src/transformers/benchmark/benchmark.py
@@ -96,7 +96,8 @@ def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_len
model = model_cls(config)
except ImportError:
raise ImportError(
- f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = MODEL_MAPPING[config.__class__](config)
@@ -151,7 +152,8 @@ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length:
model = model_cls(config)
except ImportError:
raise ImportError(
- f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
@@ -230,7 +232,8 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
if self.args.is_tpu:
# tpu
raise NotImplementedError(
- "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `--no-memory` or `args.memory=False`"
+ "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with"
+ " `--no-memory` or `args.memory=False`"
)
elif self.args.is_gpu:
if not is_py3nvml_available():
@@ -241,7 +244,8 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
memory = "N/A"
else:
logger.info(
- "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU."
+ "Measuring total GPU usage on GPU device. Make sure to not have additional processes running"
+ " on the same GPU."
)
# init nvml
nvml.nvmlInit()
diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py
index dbdf9d8a36734b..57af2481ef2cab 100644
--- a/src/transformers/benchmark/benchmark_args.py
+++ b/src/transformers/benchmark/benchmark_args.py
@@ -54,7 +54,8 @@ def __init__(self, **kwargs):
positive_arg = deprecated_arg[3:]
setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
logger.warning(
- f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or {positive_arg}={kwargs[positive_arg]}"
+ f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or"
+ f" {positive_arg}={kwargs[positive_arg]}"
)
self.torchscript = kwargs.pop("torchscript", self.torchscript)
diff --git a/src/transformers/benchmark/benchmark_args_tf.py b/src/transformers/benchmark/benchmark_args_tf.py
index 7ec5054cb37c77..8f3a9cea946555 100644
--- a/src/transformers/benchmark/benchmark_args_tf.py
+++ b/src/transformers/benchmark/benchmark_args_tf.py
@@ -51,7 +51,8 @@ def __init__(self, **kwargs):
positive_arg = deprecated_arg[3:]
kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
logger.warning(
- f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or {positive_arg}={kwargs[positive_arg]}"
+ f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or"
+ f" {positive_arg}={kwargs[positive_arg]}"
)
self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
self.device_idx = kwargs.pop("device_idx", self.device_idx)
diff --git a/src/transformers/benchmark/benchmark_args_utils.py b/src/transformers/benchmark/benchmark_args_utils.py
index b2f76f809f189c..d9233906d281c9 100644
--- a/src/transformers/benchmark/benchmark_args_utils.py
+++ b/src/transformers/benchmark/benchmark_args_utils.py
@@ -43,7 +43,10 @@ class BenchmarkArguments:
models: List[str] = list_field(
default=[],
metadata={
- "help": "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version of all available models"
+ "help": (
+ "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version"
+ " of all available models"
+ )
},
)
@@ -87,7 +90,11 @@ class BenchmarkArguments:
multi_process: bool = field(
default=True,
metadata={
- "help": "Whether to use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be disabled for debugging / testing and on TPU."
+ "help": (
+ "Whether to use multiprocessing for memory and speed measurement. It is highly recommended to use"
+ " multiprocessing for accurate CPU and GPU memory measurements. This option should only be disabled"
+ " for debugging / testing and on TPU."
+ )
},
)
inference_time_csv_file: str = field(
@@ -118,7 +125,10 @@ class BenchmarkArguments:
only_pretrain_model: bool = field(
default=False,
metadata={
- "help": "Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain model weights."
+ "help": (
+ "Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain"
+ " model weights."
+ )
},
)
@@ -138,9 +148,10 @@ def to_json_string(self):
@property
def model_names(self):
- assert (
- len(self.models) > 0
- ), "Please make sure you provide at least one model name / model identifier, *e.g.* `--models bert-base-cased` or `args.models = ['bert-base-cased']."
+ assert len(self.models) > 0, (
+ "Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
+ " bert-base-cased` or `args.models = ['bert-base-cased']."
+ )
return self.models
@property
diff --git a/src/transformers/benchmark/benchmark_tf.py b/src/transformers/benchmark/benchmark_tf.py
index 0eb0db64a8d6f4..b5fd4b71b562a2 100644
--- a/src/transformers/benchmark/benchmark_tf.py
+++ b/src/transformers/benchmark/benchmark_tf.py
@@ -140,7 +140,8 @@ def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_len
model = model_cls(config)
except ImportError:
raise ImportError(
- f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = TF_MODEL_MAPPING[config.__class__](config)
@@ -184,7 +185,8 @@ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length:
model = model_cls(config)
except ImportError:
raise ImportError(
- f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
)
else:
model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
@@ -239,15 +241,17 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
with self.args.strategy.scope():
try:
if self.args.trace_memory_line_by_line:
- assert (
- self.args.eager_mode
- ), "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory consumption line by line."
+ assert self.args.eager_mode, (
+ "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory"
+ " consumption line by line."
+ )
trace = start_memory_tracing("transformers")
if self.args.is_tpu:
# tpu
raise NotImplementedError(
- "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.memory=False`"
+ "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking"
+ " with `args.memory=False`"
)
elif self.args.is_gpu:
# gpu
@@ -259,7 +263,8 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
memory = "N/A"
else:
logger.info(
- "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU."
+ "Measuring total GPU usage on GPU device. Make sure to not have additional processes"
+ " running on the same GPU."
)
# init nvml
nvml.nvmlInit()
@@ -274,7 +279,8 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
# cpu
if self.args.trace_memory_line_by_line:
logger.info(
- "When enabling line by line tracing, the max peak memory for CPU is inaccurate in TensorFlow."
+ "When enabling line by line tracing, the max peak memory for CPU is inaccurate in"
+ " TensorFlow."
)
memory = None
else:
diff --git a/src/transformers/benchmark/benchmark_utils.py b/src/transformers/benchmark/benchmark_utils.py
index 7e738bb601cf18..36fe5eb116cbef 100644
--- a/src/transformers/benchmark/benchmark_utils.py
+++ b/src/transformers/benchmark/benchmark_utils.py
@@ -379,7 +379,7 @@ def start_memory_tracing(
devices = list(range(nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
nvml.nvmlShutdown()
except (OSError, nvml.NVMLError):
- logger.warning("Error while initializing communication with GPU. " "We won't perform GPU memory tracing.")
+ logger.warning("Error while initializing communication with GPU. We won't perform GPU memory tracing.")
log_gpu = False
else:
log_gpu = is_torch_available() or is_tf_available()
@@ -626,7 +626,8 @@ def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig =
if self.args.memory and os.getenv("TRANSFORMERS_USE_MULTIPROCESSING") == 0:
logger.warning(
- "Memory consumption will not be measured accurately if `args.multi_process` is set to `False.` The flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
+ "Memory consumption will not be measured accurately if `args.multi_process` is set to `False.` The"
+ " flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
)
self._print_fn = None
@@ -732,7 +733,8 @@ def run(self):
self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
if self.args.is_tpu:
self.print_fn(
- "TPU was used for inference. Note that the time after compilation stabilized (after ~10 inferences model.forward(..) calls) was measured."
+ "TPU was used for inference. Note that the time after compilation stabilized (after ~10"
+ " inferences model.forward(..) calls) was measured."
)
if self.args.memory:
@@ -751,7 +753,8 @@ def run(self):
self.save_to_csv(train_result_time, self.args.train_time_csv_file)
if self.args.is_tpu:
self.print_fn(
- "TPU was used for training. Note that the time after compilation stabilized (after ~10 train loss=model.forward(...) + loss.backward() calls) was measured."
+ "TPU was used for training. Note that the time after compilation stabilized (after ~10 train"
+ " loss=model.forward(...) + loss.backward() calls) was measured."
)
if self.args.memory:
diff --git a/src/transformers/commands/add_new_model.py b/src/transformers/commands/add_new_model.py
index 276032eefe6327..85d053a14873a3 100644
--- a/src/transformers/commands/add_new_model.py
+++ b/src/transformers/commands/add_new_model.py
@@ -15,6 +15,7 @@
import json
import os
import shutil
+import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List
@@ -54,6 +55,11 @@ def __init__(self, testing: bool, testing_file: str, path=None, *args):
self._path = path
def run(self):
+ warnings.warn(
+ "The command `transformers-cli add-new-model` is deprecated and will be removed in v5 of Transformers. "
+ "It is not actively maintained anymore, so might give a result that won't pass all tests and quality "
+ "checks, you should use `transformers-cli add-new-model-like` instead."
+ )
if not _has_cookiecutter:
raise ImportError(
"Model creation dependencies are required to use the `add_new_model` command. Install them by running "
@@ -102,10 +108,10 @@ def run(self):
model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
os.makedirs(model_dir, exist_ok=True)
- os.makedirs(f"{path_to_transformer_root}/tests/{lowercase_model_name}", exist_ok=True)
+ os.makedirs(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}", exist_ok=True)
# Tests require submodules as they have parent imports
- with open(f"{path_to_transformer_root}/tests/{lowercase_model_name}/__init__.py", "w"):
+ with open(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/__init__.py", "w"):
pass
shutil.move(
@@ -136,7 +142,7 @@ def remove_copy_lines(path):
shutil.move(
f"{directory}/test_modeling_{lowercase_model_name}.py",
- f"{path_to_transformer_root}/tests/{lowercase_model_name}/test_modeling_{lowercase_model_name}.py",
+ f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_{lowercase_model_name}.py",
)
else:
os.remove(f"{directory}/modeling_{lowercase_model_name}.py")
@@ -153,7 +159,7 @@ def remove_copy_lines(path):
shutil.move(
f"{directory}/test_modeling_tf_{lowercase_model_name}.py",
- f"{path_to_transformer_root}/tests/{lowercase_model_name}/test_modeling_tf_{lowercase_model_name}.py",
+ f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_tf_{lowercase_model_name}.py",
)
else:
os.remove(f"{directory}/modeling_tf_{lowercase_model_name}.py")
@@ -170,7 +176,7 @@ def remove_copy_lines(path):
shutil.move(
f"{directory}/test_modeling_flax_{lowercase_model_name}.py",
- f"{path_to_transformer_root}/tests/{lowercase_model_name}/test_modeling_flax_{lowercase_model_name}.py",
+ f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_flax_{lowercase_model_name}.py",
)
else:
os.remove(f"{directory}/modeling_flax_{lowercase_model_name}.py")
diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py
index 30a8229b57ab97..2c22acc26d4de0 100644
--- a/src/transformers/commands/add_new_model_like.py
+++ b/src/transformers/commands/add_new_model_like.py
@@ -18,6 +18,7 @@
import re
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
+from datetime import date
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
@@ -32,6 +33,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+CURRENT_YEAR = date.today().year
TRANSFORMERS_PATH = Path(__file__).parent.parent
REPO_PATH = TRANSFORMERS_PATH.parent.parent
@@ -421,6 +423,7 @@ def duplicate_module(
with open(module_file, "r", encoding="utf-8") as f:
content = f.read()
+ content = re.sub("# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content)
objects = parse_module_content(content)
# Loop and treat all objects
@@ -554,7 +557,7 @@ def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) ->
]
test_files = filter_framework_files(test_files, frameworks=frameworks)
# Add the test directory
- test_files = [REPO_PATH / "tests" / module_name / f for f in test_files]
+ test_files = [REPO_PATH / "tests" / "models" / module_name / f for f in test_files]
# Filter by existing files
test_files = [f for f in test_files if f.exists()]
@@ -766,7 +769,9 @@ def clean_frameworks_in_init(
return
remove_pattern = "|".join(to_remove)
- re_conditional_imports = re.compile(rf"^\s*if is_({remove_pattern})_available\(\):\s*$")
+ re_conditional_imports = re.compile(rf"^\s*if not is_({remove_pattern})_available\(\):\s*$")
+ re_try = re.compile(r"\s*try:")
+ re_else = re.compile(r"\s*else:")
re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")
with open(init_file, "r", encoding="utf-8") as f:
@@ -776,11 +781,15 @@ def clean_frameworks_in_init(
new_lines = []
idx = 0
while idx < len(lines):
- # Conditional imports
- if re_conditional_imports.search(lines[idx]) is not None:
+ # Conditional imports in try-except-else blocks
+ if (re_conditional_imports.search(lines[idx]) is not None) and (re_try.search(lines[idx - 1]) is not None):
+ # Remove the preceding `try:`
+ new_lines.pop()
idx += 1
- while is_empty_line(lines[idx]):
+ # Iterate until `else:`
+ while is_empty_line(lines[idx]) or re_else.search(lines[idx]) is None:
idx += 1
+ idx += 1
indent = find_indent(lines[idx])
while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
idx += 1
@@ -790,6 +799,7 @@ def clean_frameworks_in_init(
for framework in to_remove:
line = line.replace(f", is_{framework}_available", "")
line = line.replace(f"is_{framework}_available, ", "")
+ line = line.replace(f"is_{framework}_available,", "")
line = line.replace(f"is_{framework}_available", "")
if len(line.strip()) > 0:
@@ -836,11 +846,11 @@ def add_model_to_main_init(
while idx < len(lines):
if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
framework = None
- elif lines[idx].lstrip().startswith("if is_torch_available"):
+ elif lines[idx].lstrip().startswith("if not is_torch_available"):
framework = "pt"
- elif lines[idx].lstrip().startswith("if is_tf_available"):
+ elif lines[idx].lstrip().startswith("if not is_tf_available"):
framework = "tf"
- elif lines[idx].lstrip().startswith("if is_flax_available"):
+ elif lines[idx].lstrip().startswith("if not is_flax_available"):
framework = "flax"
# Skip if we are in a framework not wanted.
@@ -1055,6 +1065,7 @@ def duplicate_doc_file(
with open(doc_file, "r", encoding="utf-8") as f:
content = f.read()
+ content = re.sub(" [batch_size, seq_length, num_heads, 3 x head_dim]
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
+ value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
+
+ if use_cache is True:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+ # [batch_size, head_dim, q_length, k_length]
+ output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
+
+ # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
+ query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
+
+ # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
+ key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
+
+ # slice alibi tensor until the query length
+ sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
+
+ # Raw attention scores. [batch_size * num_heads, q_length, k_length]
+ beta = 1.0 / self.layer_number
+
+ matmul_result = torch.baddbmm(
+ sliced_alibi,
+ query_layer.transpose(1, 0),
+ key_layer.transpose(1, 0).transpose(1, 2),
+ beta=beta,
+ alpha=(1.0 / self.norm_factor),
+ )
+
+ # change view to [batch_size, num_heads, q_length, k_length]
+ attention_scores = matmul_result.view(*output_size)
+
+ # attention scores and attention mask [b, np, sq, sk]
+ max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
+ attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
+ value_layer.dtype
+ )
+ attention_probs = self.attention_dropout(attention_probs)
+
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ # context layer shape: [batch_size, num_heads, q_length, head_dim]
+ output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+
+ # change view [k_length, batch_size x num_heads, head_dim]
+ value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
+
+ # change view [batch_size x num_heads, q_length, k_length]
+ attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
+
+ # change view [batch_size, num_heads, q_length, head_dim]
+ context_layer = context_layer.view(*output_size)
+
+ # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
+
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ # Output. [q_length, batch_size, hidden_size]
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = context_layer.shape[-1] / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + nn.functional.linear(
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ output = output_tensor.transpose(1, 0)
+
+ output = dropout_add(output, residual, self.hidden_dropout, self.training)
+
+ outputs = (output, present)
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs
+
+
+class BloomMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+ self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
+ self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
+ self.hidden_dropout = config.hidden_dropout
+ self.gelu_impl = BloomGelu()
+
+ def forward(self, hidden_states, residual):
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ intermediate_output = torch.zeros_like(residual)
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+ for i in range(self.pretraining_tp):
+ intermediate_output = intermediate_output + nn.functional.linear(
+ hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ intermediate_output = self.dense_4h_to_h(hidden_states)
+
+ output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+
+ return output
+
+
+class BloomBlock(nn.Module):
+ def __init__(self, config, layer_number=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.n_head = config.n_head
+ self.self_attention = BloomAttention(config, layer_number=layer_number)
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = BloomMLP(config)
+
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(
+ self,
+ hidden_states,
+ layer_past=None,
+ attention_mask=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ alibi=None,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+
+class BloomPreTrainedModel(PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BloomConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BloomBlock"]
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, BloomModel):
+ module.gradient_checkpointing = value
+
+
+BLOOM_START_DOCSTRING = r"""
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+BLOOM_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
+ `past_key_values`).
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
+ BLOOM_START_DOCSTRING,
+)
+class BloomModel(BloomPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+ self.n_head = config.n_head
+
+ # Embedding + LN Embedding
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
+
+ self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ # Transformer blocks
+ self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
+
+ # Final Layer Norm
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings):
+ self.word_embeddings = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_head x N x N
+ # head_mask has shape n_layer x batch x n_head x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ current_sequence_length = hidden_states.shape[1]
+ if past_key_values[0] is not None:
+ current_sequence_length += past_key_values[0][0].shape[1]
+ alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions, alibi)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = hidden_states.view(output_shape)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ BLOOM_START_DOCSTRING,
+)
+class BloomForCausalLM(BloomPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = BloomModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+ # only last token for inputs_ids if past is defined in kwargs
+ if past:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ attention_mask = kwargs.get("attention_mask", None)
+ position_ids = kwargs.get("position_ids", None)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+ else:
+ position_ids = None
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "use_cache": kwargs.get("use_cache"),
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ }
+
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutputWithCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+ """
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+ beam_idx at every generation step.
+ """
+ return tuple(
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+ for layer_past in past
+ )
+
+
+@add_start_docstrings(
+ """
+ The Bloom Model transformer with a sequence classification head on top (linear layer).
+
+ [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ BLOOM_START_DOCSTRING,
+)
+class BloomForSequenceClassification(BloomPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.transformer = BloomModel(config)
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+ else:
+ sequence_lengths = -1
+ logger.warning(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ BLOOM_START_DOCSTRING,
+)
+class BloomForTokenClassification(BloomPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.transformer = BloomModel(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
diff --git a/src/transformers/models/bloom/tokenization_bloom_fast.py b/src/transformers/models/bloom/tokenization_bloom_fast.py
new file mode 100644
index 00000000000000..c9785d641bbbcf
--- /dev/null
+++ b/src/transformers/models/bloom/tokenization_bloom_fast.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Bloom."""
+
+
+import json
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from tokenizers import pre_tokenizers
+
+from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from transformers.pipelines.conversational import Conversation
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "tokenizer_file": {
+ "bigscience/tokenizer": "https://huggingface.co/bigscience/tokenizer/blob/main/tokenizer.json",
+ "bigscience/bloom-350m": "https://huggingface.co/bigscience/bloom-350m/blob/main/tokenizer.json",
+ "bigscience/bloom-760m": "https://huggingface.co/bigscience/bloom-760m/blob/main/tokenizer.json",
+ "bigscience/bloom-1b3": "https://huggingface.co/bigscience/bloom-1b3/blob/main/tokenizer.json",
+ "bigscience/bloom-2b5": "https://huggingface.co/bigscience/bloom-2b5/blob/main/tokenizer.json",
+ "bigscience/bloom-6b3": "https://huggingface.co/bigscience/bloom-2b5/blob/main/tokenizer.json",
+ "bigscience/bloom": "https://huggingface.co/bigscience/bloom/blob/main/tokenizer.json",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "bigscience/tokenizer": 1024,
+ "bigscience/bloom-350m": 1024,
+ "bigscience/bloom-760m": 1024,
+ "bigscience/bloom-1b3": 1024,
+ "bigscience/bloom-2b5": 1024,
+ "bigscience/bloom-6b3": 1024,
+ "bigscience/bloom": 1024,
+}
+
+
+class BloomTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" Bloom tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```
+ >>> from transformers import BloomTokenizerFast
+ >>> tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
+ >>> tokenizer("Hello world")['input_ids']
+ [15496, 995]
+ >>> tokenizer(" Hello world")['input_ids']
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+ the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The end of sequence token.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+ other word. (Bloom tokenizer detect beginning of words by the preceding space).
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether or not the post-processing step should trim offsets to avoid including whitespaces.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = None
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="",
+ bos_token="",
+ eos_token="",
+ pad_token="",
+ add_prefix_space=False,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+
+ self.add_prefix_space = add_prefix_space
+
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+ if not (self.add_prefix_space or not is_split_into_words):
+ raise Exception(
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
+ " pretokenized inputs."
+ )
+
+ return super()._batch_encode_plus(*args, **kwargs)
+
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ if not (self.add_prefix_space or not is_split_into_words):
+ raise Exception(
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
+ " pretokenized inputs."
+ )
+
+ return super()._encode_plus(*args, **kwargs)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
+ """This corresponds to DialoGPT variants of models."""
+ input_ids = []
+ for is_user, text in conversation.iter_texts():
+ input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
+
+ if len(input_ids) > self.model_max_length:
+ input_ids = input_ids[-self.model_max_length :]
+ return input_ids
diff --git a/src/transformers/models/byt5/__init__.py b/src/transformers/models/byt5/__init__.py
index ec9a03212f438a..d7cffb390beb73 100644
--- a/src/transformers/models/byt5/__init__.py
+++ b/src/transformers/models/byt5/__init__.py
@@ -21,9 +21,7 @@
from ...utils import _LazyModule
-_import_structure = {
- "tokenization_byt5": ["ByT5Tokenizer"],
-}
+_import_structure = {"tokenization_byt5": ["ByT5Tokenizer"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py
index a0020301682293..7d9a20f3b0b395 100755
--- a/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py
@@ -49,8 +49,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained T5 model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/byt5/tokenization_byt5.py b/src/transformers/models/byt5/tokenization_byt5.py
index 77eb34f9295ed8..0071d7a9afe4ed 100644
--- a/src/transformers/models/byt5/tokenization_byt5.py
+++ b/src/transformers/models/byt5/tokenization_byt5.py
@@ -77,8 +77,9 @@ def __init__(
extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
if extra_tokens != extra_ids:
raise ValueError(
- f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to ByT5Tokenizer. "
- "In this case the additional_special_tokens must include the extra_ids tokens"
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+ " provided to ByT5Tokenizer. In this case the additional_special_tokens must include the"
+ " extra_ids tokens"
)
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
@@ -146,7 +147,8 @@ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
"""Do not add eos again if user already added it."""
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
warnings.warn(
- f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+ " eos tokens being added."
)
return token_ids
else:
diff --git a/src/transformers/models/camembert/__init__.py b/src/transformers/models/camembert/__init__.py
index fccb4c49c9a777..c91683d1cde4d6 100644
--- a/src/transformers/models/camembert/__init__.py
+++ b/src/transformers/models/camembert/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tf_available,
@@ -31,13 +32,28 @@
"configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig", "CamembertOnnxConfig"],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_camembert"] = ["CamembertTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_camembert_fast"] = ["CamembertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_camembert"] = [
"CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"CamembertForCausalLM",
@@ -49,7 +65,12 @@
"CamembertModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_camembert"] = [
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFCamembertForCausalLM",
@@ -65,13 +86,28 @@
if TYPE_CHECKING:
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig, CamembertOnnxConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_camembert import CamembertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_camembert_fast import CamembertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
CamembertForCausalLM,
@@ -83,7 +119,12 @@
CamembertModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForCausalLM,
diff --git a/src/transformers/models/camembert/configuration_camembert.py b/src/transformers/models/camembert/configuration_camembert.py
index 982afceb70beb7..6f872237327e4e 100644
--- a/src/transformers/models/camembert/configuration_camembert.py
+++ b/src/transformers/models/camembert/configuration_camembert.py
@@ -27,8 +27,12 @@
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"camembert-base": "https://huggingface.co/camembert-base/resolve/main/config.json",
- "umberto-commoncrawl-cased-v1": "https://huggingface.co/Musixmatch/umberto-commoncrawl-cased-v1/resolve/main/config.json",
- "umberto-wikipedia-uncased-v1": "https://huggingface.co/Musixmatch/umberto-wikipedia-uncased-v1/resolve/main/config.json",
+ "umberto-commoncrawl-cased-v1": (
+ "https://huggingface.co/Musixmatch/umberto-commoncrawl-cased-v1/resolve/main/config.json"
+ ),
+ "umberto-wikipedia-uncased-v1": (
+ "https://huggingface.co/Musixmatch/umberto-wikipedia-uncased-v1/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/canine/__init__.py b/src/transformers/models/canine/__init__.py
index 1d24a01b14b53e..307a819e128419 100644
--- a/src/transformers/models/canine/__init__.py
+++ b/src/transformers/models/canine/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -25,7 +25,12 @@
"tokenization_canine": ["CanineTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_canine"] = [
"CANINE_PRETRAINED_MODEL_ARCHIVE_LIST",
"CanineForMultipleChoice",
@@ -43,7 +48,12 @@
from .configuration_canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig
from .tokenization_canine import CanineTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_canine import (
CANINE_PRETRAINED_MODEL_ARCHIVE_LIST,
CanineForMultipleChoice,
diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py
index 2d903109ac0386..bb7b1492c7bf78 100644
--- a/src/transformers/models/canine/modeling_canine.py
+++ b/src/transformers/models/canine/modeling_canine.py
@@ -19,7 +19,7 @@
import math
import os
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -253,11 +253,11 @@ def _embed_hash_buckets(self, input_ids, embedding_size: int, num_hashes: int, n
def forward(
self,
- input_ids=None,
- token_type_ids=None,
- position_ids=None,
- inputs_embeds=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -356,7 +356,11 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, inputs, final_seq_char_positions=None):
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ final_seq_char_positions: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
# inputs has shape [batch, mol_seq, molecule_hidden_size+char_hidden_final]
# we transpose it to be [batch, molecule_hidden_size+char_hidden_final, mol_seq]
inputs = torch.transpose(inputs, 1, 2)
@@ -419,12 +423,12 @@ def transpose_for_scores(self, x):
def forward(
self,
- from_tensor,
- to_tensor,
- attention_mask=None,
- head_mask=None,
- output_attentions=False,
- ):
+ from_tensor: torch.Tensor,
+ to_tensor: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
mixed_query_layer = self.query(from_tensor)
# If this is instantiated as a cross-attention module, the keys
@@ -496,7 +500,9 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states, input_tensor):
+ def forward(
+ self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -540,12 +546,11 @@ def __init__(
self.local = local
if attend_from_chunk_width < attend_from_chunk_stride:
raise ValueError(
- "`attend_from_chunk_width` < `attend_from_chunk_stride` "
- "would cause sequence positions to get skipped."
+ "`attend_from_chunk_width` < `attend_from_chunk_stride` would cause sequence positions to get skipped."
)
if attend_to_chunk_width < attend_to_chunk_stride:
raise ValueError(
- "`attend_to_chunk_width` < `attend_to_chunk_stride`" "would cause sequence positions to get skipped."
+ "`attend_to_chunk_width` < `attend_to_chunk_stride`would cause sequence positions to get skipped."
)
self.always_attend_to_first_position = always_attend_to_first_position
self.first_position_attends_to_all = first_position_attends_to_all
@@ -574,11 +579,11 @@ def prune_heads(self, heads):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- output_attentions=False,
- ):
+ hidden_states: Tuple[torch.FloatTensor],
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
if not self.local:
self_outputs = self.self(hidden_states, hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self_outputs[0]
@@ -656,7 +661,7 @@ def __init__(self, config):
else:
self.intermediate_act_fn = config.hidden_act
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
@@ -669,7 +674,7 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states, input_tensor):
+ def forward(self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -706,11 +711,11 @@ def __init__(
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- output_attentions=False,
- ):
+ hidden_states: Tuple[torch.FloatTensor],
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
@@ -767,13 +772,13 @@ def __init__(
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- ):
+ hidden_states: Tuple[torch.FloatTensor],
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
@@ -822,7 +827,7 @@ def __init__(self, config):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
- def forward(self, hidden_states):
+ def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
@@ -841,7 +846,7 @@ def __init__(self, config):
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
@@ -862,7 +867,7 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
- def forward(self, hidden_states):
+ def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
@@ -873,7 +878,10 @@ def __init__(self, config):
super().__init__()
self.predictions = CanineLMPredictionHead(config)
- def forward(self, sequence_output):
+ def forward(
+ self,
+ sequence_output: Tuple[torch.Tensor],
+ ) -> Tuple[torch.Tensor]:
prediction_scores = self.predictions(sequence_output)
return prediction_scores
@@ -1093,16 +1101,16 @@ def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: torch.Tens
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CanineModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1130,12 +1138,12 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
molecule_attention_mask = self._downsample_attention_mask(
attention_mask, downsampling_rate=self.config.downsampling_rate
)
extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]), device
+ molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
)
# Prepare head mask if needed
@@ -1275,17 +1283,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1372,17 +1380,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1465,17 +1473,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1543,18 +1551,18 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- start_positions=None,
- end_positions=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/src/transformers/models/clip/__init__.py b/src/transformers/models/clip/__init__.py
index 67e6841e6d0206..6a6e64c995d385 100644
--- a/src/transformers/models/clip/__init__.py
+++ b/src/transformers/models/clip/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
@@ -32,14 +33,29 @@
"tokenization_clip": ["CLIPTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_clip_fast"] = ["CLIPTokenizerFast"]
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_clip"] = ["CLIPFeatureExtractor"]
_import_structure["processing_clip"] = ["CLIPProcessor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_clip"] = [
"CLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"CLIPModel",
@@ -48,7 +64,12 @@
"CLIPVisionModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_clip"] = [
"TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFCLIPModel",
@@ -57,7 +78,12 @@
"TFCLIPVisionModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_clip"] = [
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
@@ -72,14 +98,29 @@
from .configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from .tokenization_clip import CLIPTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_clip_fast import CLIPTokenizerFast
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_clip import CLIPFeatureExtractor
from .processing_clip import CLIPProcessor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_clip import (
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
CLIPModel,
@@ -88,7 +129,12 @@
CLIPVisionModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_clip import (
TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCLIPModel,
@@ -97,7 +143,12 @@
TFCLIPVisionModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
diff --git a/src/transformers/models/clip/feature_extraction_clip.py b/src/transformers/models/clip/feature_extraction_clip.py
index 7614d05afd3e33..7f01b5e02b94df 100644
--- a/src/transformers/models/clip/feature_extraction_clip.py
+++ b/src/transformers/models/clip/feature_extraction_clip.py
@@ -54,6 +54,8 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
+ convert_rgb (`bool`, defaults to `True`):
+ Whether or not to convert `PIL.Image.Image` into `RGB` format
"""
model_input_names = ["pixel_values"]
@@ -68,6 +70,7 @@ def __init__(
do_normalize=True,
image_mean=None,
image_std=None,
+ do_convert_rgb=True,
**kwargs
):
super().__init__(**kwargs)
@@ -79,6 +82,7 @@ def __init__(
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
+ self.do_convert_rgb = do_convert_rgb
def __call__(
self,
@@ -141,9 +145,14 @@ def __call__(
if not is_batched:
images = [images]
- # transformations (resizing + center cropping + normalization)
+ # transformations (convert rgb + resizing + center cropping + normalization)
+ if self.do_convert_rgb:
+ images = [self.convert_rgb(image) for image in images]
if self.do_resize and self.size is not None and self.resample is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
+ images = [
+ self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
+ for image in images
+ ]
if self.do_center_crop and self.crop_size is not None:
images = [self.center_crop(image, self.crop_size) for image in images]
if self.do_normalize:
@@ -154,56 +163,3 @@ def __call__(
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
-
- def center_crop(self, image, size):
- """
- Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
- size is given, it will be padded (so the returned result has the size asked).
-
- Args:
- image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
- The image to resize.
- size (`int` or `Tuple[int, int]`):
- The size to which crop the image.
- """
- self._ensure_format_supported(image)
- if not isinstance(size, tuple):
- size = (size, size)
-
- if not isinstance(image, Image.Image):
- image = self.to_pil_image(image)
-
- image_width, image_height = image.size
- crop_height, crop_width = size
-
- crop_top = int((image_height - crop_height + 1) * 0.5)
- crop_left = int((image_width - crop_width + 1) * 0.5)
-
- return image.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
-
- def resize(self, image, size, resample=Image.BICUBIC):
- """
- Resizes `image`. Note that this will trigger a conversion of `image` to a PIL Image.
-
- Args:
- image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
- The image to resize.
- size (`int` or `Tuple[int, int]`):
- The size to use for resizing the image. If `int` it will be resized to match the shorter side
- resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
- The filter to user for resampling.
- """
- self._ensure_format_supported(image)
-
- if not isinstance(image, Image.Image):
- image = self.to_pil_image(image)
- if isinstance(size, tuple):
- new_w, new_h = size
- else:
- width, height = image.size
- short, long = (width, height) if width <= height else (height, width)
- if short == size:
- return image
- new_short, new_long = size, int(size * long / short)
- new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
- return image.resize((new_w, new_h), resample)
diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index 44c340847ed1b3..5c34c658b59fc8 100755
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -57,7 +57,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# contrastive loss function, adapted from
@@ -181,7 +181,8 @@ def __init__(self, config):
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
@@ -220,14 +221,16 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {causal_attention_mask.size()}"
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -258,7 +261,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -670,7 +674,7 @@ def _build_causal_attention_mask(self, bsz, seq_len):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
- mask.fill_(float("-inf"))
+ mask.fill_(torch.tensor(float("-inf")))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
@@ -848,12 +852,14 @@ def __init__(self, config: CLIPConfig):
if not isinstance(config.text_config, CLIPTextConfig):
raise ValueError(
- f"config.text_config is expected to be of type CLIPTextConfig but is of type {type(config.text_config)}."
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
+ f" {type(config.text_config)}."
)
if not isinstance(config.vision_config, CLIPVisionConfig):
raise ValueError(
- f"config.vision_config is expected to be of type CLIPVisionConfig but is of type {type(config.vision_config)}."
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
+ f" {type(config.vision_config)}."
)
text_config = config.text_config
@@ -1036,8 +1042,8 @@ def forward(
text_embeds = self.text_projection(text_embeds)
# normalized features
- image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
- text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py
index 792c7b53253b3b..aa8ef87d5bf10b 100644
--- a/src/transformers/models/clip/modeling_flax_clip.py
+++ b/src/transformers/models/clip/modeling_flax_clip.py
@@ -262,7 +262,8 @@ def setup(self):
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = self.config.attention_dropout
diff --git a/src/transformers/models/clip/modeling_tf_clip.py b/src/transformers/models/clip/modeling_tf_clip.py
index 5d209620015220..6ba83f04b8436b 100644
--- a/src/transformers/models/clip/modeling_tf_clip.py
+++ b/src/transformers/models/clip/modeling_tf_clip.py
@@ -266,7 +266,8 @@ def __init__(self, config: CLIPConfig, **kwargs):
self.attention_head_size = self.embed_dim // self.num_attention_heads
if self.attention_head_size * self.num_attention_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_attention_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_attention_heads})."
)
factor = config.initializer_factor
@@ -551,11 +552,14 @@ def call(
)
def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32):
-
- diag = tf.constant(0.0, shape=(seq_length,), dtype=dtype)
+ # It is possible with an unspecified sequence length for seq_length to be
+ # a runtime value, which is unsupported by tf.constant. Per the TensorFlow
+ # docs, tf.fill can handle runtime dynamic shapes:
+ # https://www.tensorflow.org/api_docs/python/tf/fill
+ diag = tf.cast(tf.fill((seq_length,), 0.0), dtype)
# set an additive 2D attention mask with all places being masked
- to_mask = tf.constant(-10000.0, shape=(seq_length, seq_length), dtype=dtype)
+ to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype)
# set diagonal & lower triangular parts to 0 (i.e. the places not to be masked)
# TIP: think the 2D matrix as the space of (query_seq, key_seq)
@@ -705,12 +709,14 @@ def __init__(self, config: CLIPConfig, **kwargs):
if not isinstance(config.text_config, CLIPTextConfig):
raise ValueError(
- f"config.text_config is expected to be of type CLIPTextConfig but is of type {type(config.text_config)}."
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
+ f" {type(config.text_config)}."
)
if not isinstance(config.vision_config, CLIPVisionConfig):
raise ValueError(
- f"config.vision_config is expected to be of type CLIPVisionConfig but is of type {type(config.vision_config)}."
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
+ f" {type(config.vision_config)}."
)
self.config = config
@@ -1082,6 +1088,18 @@ def call(
return outputs
+ @tf.function(
+ input_signature=[
+ {
+ "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ }
+ ]
+ )
+ def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBaseModelOutputWithPooling:
+ output = self.call(inputs)
+ return self.serving_output(output)
+
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
@@ -1123,7 +1141,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
}
]
)
- def serving(self, inputs):
+ def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBaseModelOutputWithPooling:
"""
Method used for serving the model.
@@ -1226,7 +1244,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
}
]
)
- def serving(self, inputs):
+ def serving(self, inputs: Dict[str, tf.Tensor]) -> TFCLIPOutput:
"""
Method used for serving the model.
@@ -1375,4 +1393,7 @@ def call(
return outputs
def serving_output(self, output: TFCLIPOutput) -> TFCLIPOutput:
+ # TODO: As is this currently fails with saved_model=True, because
+ # TensorFlow cannot trace through nested dataclasses. Reference:
+ # https://github.com/huggingface/transformers/pull/16886
return output
diff --git a/src/transformers/models/clip/tokenization_clip.py b/src/transformers/models/clip/tokenization_clip.py
index 81fb7159efdb54..c6870cc69f5526 100644
--- a/src/transformers/models/clip/tokenization_clip.py
+++ b/src/transformers/models/clip/tokenization_clip.py
@@ -345,7 +345,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
diff --git a/src/transformers/models/clip/tokenization_clip_fast.py b/src/transformers/models/clip/tokenization_clip_fast.py
index f6ff684c6b63a3..5fe6d3d445bb09 100644
--- a/src/transformers/models/clip/tokenization_clip_fast.py
+++ b/src/transformers/models/clip/tokenization_clip_fast.py
@@ -36,7 +36,9 @@
"openai/clip-vit-base-patch32": "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/merges.txt",
},
"tokenizer_file": {
- "openai/clip-vit-base-patch32": "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/tokenizer.json",
+ "openai/clip-vit-base-patch32": (
+ "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/tokenizer.json"
+ ),
},
}
@@ -97,12 +99,12 @@ def __init__(
if not isinstance(self.backend_tokenizer.pre_tokenizer, pre_tokenizers.Sequence):
raise ValueError(
- "The `backend_tokenizer` provided does not match the expected format. The CLIP tokenizer has been "
- "heavily modified from transformers version 4.17.0. You need to convert the tokenizer you are using to be compatible with this version."
- "The easiest way to do so is "
- '`CLIPTokenizerFast.from_pretrained("path_to_local_folder_or_hub_repo, from_slow=True)`.'
- " If you want to use your existing tokenizer, you will have to revert to a version prior to "
- "4.17.0 of transformers."
+ "The `backend_tokenizer` provided does not match the expected format. The CLIP tokenizer has been"
+ " heavily modified from transformers version 4.17.0. You need to convert the tokenizer you are using"
+ " to be compatible with this version.The easiest way to do so is"
+ ' `CLIPTokenizerFast.from_pretrained("path_to_local_folder_or_hub_repo, from_slow=True)`. If you want'
+ " to use your existing tokenizer, you will have to revert to a version prior to 4.17.0 of"
+ " transformers."
)
self._wrap_decode_method_backend_tokenizer()
diff --git a/src/transformers/models/convbert/__init__.py b/src/transformers/models/convbert/__init__.py
index d4f44482e0bd80..1f8224f4b6489c 100644
--- a/src/transformers/models/convbert/__init__.py
+++ b/src/transformers/models/convbert/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -25,10 +31,20 @@
"tokenization_convbert": ["ConvBertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_convbert_fast"] = ["ConvBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_convbert"] = [
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConvBertForMaskedLM",
@@ -43,7 +59,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_convbert"] = [
"TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFConvBertForMaskedLM",
@@ -61,10 +82,20 @@
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertOnnxConfig
from .tokenization_convbert import ConvBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_convbert_fast import ConvBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_convbert import (
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
ConvBertForMaskedLM,
@@ -78,7 +109,12 @@
load_tf_weights_in_convbert,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_convbert import (
TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFConvBertForMaskedLM,
diff --git a/src/transformers/models/convbert/configuration_convbert.py b/src/transformers/models/convbert/configuration_convbert.py
index c424326b2b9b45..2b5bc42502db4a 100644
--- a/src/transformers/models/convbert/configuration_convbert.py
+++ b/src/transformers/models/convbert/configuration_convbert.py
@@ -26,7 +26,9 @@
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"YituTech/conv-bert-base": "https://huggingface.co/YituTech/conv-bert-base/resolve/main/config.json",
- "YituTech/conv-bert-medium-small": "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/config.json",
+ "YituTech/conv-bert-medium-small": (
+ "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/config.json"
+ ),
"YituTech/conv-bert-small": "https://huggingface.co/YituTech/conv-bert-small/resolve/main/config.json",
# See all ConvBERT models at https://huggingface.co/models?filter=convbert
}
@@ -37,9 +39,11 @@ class ConvBertConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`ConvBertModel`]. It is used to instantiate an
ConvBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the ConvBERT
- [conv-bert-base](https://huggingface.co/YituTech/conv-bert-base) architecture. Configuration objects inherit from
- [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
- for more information.
+ [YituTech/conv-bert-base](https://huggingface.co/YituTech/conv-bert-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
Args:
vocab_size (`int`, *optional*, defaults to 30522):
diff --git a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py
index cdea57cc24f236..3d4ff779874b30 100644
--- a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py
+++ b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py
@@ -45,8 +45,10 @@ def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_f
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained ConvBERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained ConvBERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py
index cf2240b79c01e0..9884d32aca7ec2 100755
--- a/src/transformers/models/convbert/modeling_convbert.py
+++ b/src/transformers/models/convbert/modeling_convbert.py
@@ -201,7 +201,13 @@ def __init__(self, config):
persistent=False,
)
- def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.LongTensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -245,7 +251,7 @@ class ConvBertPreTrainedModel(PreTrainedModel):
base_model_prefix = "convbert"
supports_gradient_checkpointing = True
authorized_missing_keys = [r"position_ids"]
- authorized_unexpected_keys = [r"convbert\.embeddings_project\.weight", r"convbert\.embeddings_project\.bias"]
+ authorized_unexpected_keys = [r"convbert.embeddings_project.weight", r"convbert.embeddings_project.bias"]
def _init_weights(self, module):
"""Initialize the weights"""
@@ -287,7 +293,7 @@ def __init__(self, config, input_filters, output_filters, kernel_size, **kwargs)
self.depthwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
self.pointwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self.depthwise(hidden_states)
x = self.pointwise(x)
x += self.bias
@@ -341,12 +347,12 @@ def transpose_for_scores(self, x):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
batch_size = hidden_states.size(0)
# If this is instantiated as a cross-attention module, the keys
@@ -426,7 +432,7 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states, input_tensor):
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -460,12 +466,12 @@ def prune_heads(self, heads):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.FloatTensor]]:
self_outputs = self.self(
hidden_states,
attention_mask,
@@ -489,7 +495,7 @@ def __init__(self, input_size, output_size, num_groups):
self.weight = nn.Parameter(torch.empty(self.num_groups, self.group_in_dim, self.group_out_dim))
self.bias = nn.Parameter(torch.empty(output_size))
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size = list(hidden_states.size())[0]
x = torch.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim])
x = x.permute(1, 0, 2)
@@ -514,7 +520,7 @@ def __init__(self, config):
else:
self.intermediate_act_fn = config.hidden_act
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
@@ -532,7 +538,7 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states, input_tensor):
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -556,13 +562,13 @@ def __init__(self, config):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.FloatTensor]]:
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
@@ -575,7 +581,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise AttributeError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
cross_attention_outputs = self.crossattention(
attention_output,
@@ -608,15 +615,15 @@ def __init__(self, config):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -684,7 +691,7 @@ def __init__(self, config):
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
@@ -795,16 +802,16 @@ class PreTrainedModel
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -833,7 +840,7 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = self.embeddings(
@@ -864,7 +871,7 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
- def forward(self, generator_hidden_states):
+ def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.dense(generator_hidden_states)
hidden_states = get_activation("gelu")(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
@@ -966,7 +973,7 @@ def __init__(self, config):
self.config = config
- def forward(self, hidden_states, **kwargs):
+ def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
x = hidden_states[:, 0, :] # take token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
diff --git a/src/transformers/models/convbert/tokenization_convbert.py b/src/transformers/models/convbert/tokenization_convbert.py
index a49e32ec00bb8b..8bf1b2826e0aed 100644
--- a/src/transformers/models/convbert/tokenization_convbert.py
+++ b/src/transformers/models/convbert/tokenization_convbert.py
@@ -24,7 +24,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"YituTech/conv-bert-base": "https://huggingface.co/YituTech/conv-bert-base/resolve/main/vocab.txt",
- "YituTech/conv-bert-medium-small": "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/vocab.txt",
+ "YituTech/conv-bert-medium-small": (
+ "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/vocab.txt"
+ ),
"YituTech/conv-bert-small": "https://huggingface.co/YituTech/conv-bert-small/resolve/main/vocab.txt",
}
}
diff --git a/src/transformers/models/convbert/tokenization_convbert_fast.py b/src/transformers/models/convbert/tokenization_convbert_fast.py
index 525e369c4bd5b0..383382e13082b8 100644
--- a/src/transformers/models/convbert/tokenization_convbert_fast.py
+++ b/src/transformers/models/convbert/tokenization_convbert_fast.py
@@ -25,7 +25,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"YituTech/conv-bert-base": "https://huggingface.co/YituTech/conv-bert-base/resolve/main/vocab.txt",
- "YituTech/conv-bert-medium-small": "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/vocab.txt",
+ "YituTech/conv-bert-medium-small": (
+ "https://huggingface.co/YituTech/conv-bert-medium-small/resolve/main/vocab.txt"
+ ),
"YituTech/conv-bert-small": "https://huggingface.co/YituTech/conv-bert-small/resolve/main/vocab.txt",
}
}
diff --git a/src/transformers/models/convnext/__init__.py b/src/transformers/models/convnext/__init__.py
index b488da389f8c5c..93000d5c66c8ce 100644
--- a/src/transformers/models/convnext/__init__.py
+++ b/src/transformers/models/convnext/__init__.py
@@ -18,17 +18,33 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_tf_available, is_torch_available, is_vision_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_torch_available,
+ is_vision_available,
+)
_import_structure = {
- "configuration_convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"],
+ "configuration_convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig", "ConvNextOnnxConfig"]
}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_convnext"] = ["ConvNextFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_convnext"] = [
"CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConvNextForImageClassification",
@@ -36,7 +52,12 @@
"ConvNextPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_convnext"] = [
"TFConvNextForImageClassification",
"TFConvNextModel",
@@ -44,12 +65,22 @@
]
if TYPE_CHECKING:
- from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig
+ from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig, ConvNextOnnxConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_convnext import ConvNextFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_convnext import (
CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
ConvNextForImageClassification,
@@ -57,7 +88,12 @@
ConvNextPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
diff --git a/src/transformers/models/convnext/configuration_convnext.py b/src/transformers/models/convnext/configuration_convnext.py
index 74067ad337bbfc..9f77c0099299ca 100644
--- a/src/transformers/models/convnext/configuration_convnext.py
+++ b/src/transformers/models/convnext/configuration_convnext.py
@@ -14,7 +14,13 @@
# limitations under the License.
""" ConvNeXT model configuration"""
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
@@ -101,3 +107,20 @@ def __init__(
self.layer_scale_init_value = layer_scale_init_value
self.drop_path_rate = drop_path_rate
self.image_size = image_size
+
+
+class ConvNextOnnxConfig(OnnxConfig):
+
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "sequence"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-5
diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py
index a0ca01d4c8ac9e..fc484627218a2a 100755
--- a/src/transformers/models/convnext/modeling_convnext.py
+++ b/src/transformers/models/convnext/modeling_convnext.py
@@ -15,6 +15,8 @@
""" PyTorch ConvNext model."""
+from typing import Optional, Tuple, Union
+
import torch
import torch.utils.checkpoint
from torch import nn
@@ -78,7 +80,7 @@ def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
@@ -98,7 +100,7 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
self.normalized_shape = (normalized_shape,)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.data_format == "channels_last":
x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
@@ -121,7 +123,7 @@ def __init__(self, config):
)
self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
- def forward(self, pixel_values):
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
embeddings = self.patch_embeddings(pixel_values)
embeddings = self.layernorm(embeddings)
return embeddings
@@ -155,7 +157,7 @@ def __init__(self, config, dim, drop_path=0):
)
self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
input = hidden_states
x = self.dwconv(hidden_states)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
@@ -197,7 +199,7 @@ def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, d
*[ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
hidden_states = self.downsampling_layer(hidden_states)
hidden_states = self.layers(hidden_states)
return hidden_states
@@ -207,8 +209,9 @@ class ConvNextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.stages = nn.ModuleList()
- drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
- cur = 0
+ drop_path_rates = [
+ x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
+ ]
prev_chs = config.hidden_sizes[0]
for i in range(config.num_stages):
out_chs = config.hidden_sizes[i]
@@ -218,13 +221,17 @@ def __init__(self, config):
out_channels=out_chs,
stride=2 if i > 0 else 1,
depth=config.depths[i],
- drop_path_rates=drop_path_rates[cur],
+ drop_path_rates=drop_path_rates[i],
)
self.stages.append(stage)
- cur += config.depths[i]
prev_chs = out_chs
- def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
all_hidden_states = () if output_hidden_states else None
for i, layer_module in enumerate(self.stages):
@@ -325,7 +332,12 @@ def __init__(self, config):
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
- def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
@@ -387,7 +399,13 @@ def __init__(self, config):
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
- def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
diff --git a/src/transformers/models/convnext/modeling_tf_convnext.py b/src/transformers/models/convnext/modeling_tf_convnext.py
index 1cb1b71b6130cc..3446925a072c89 100644
--- a/src/transformers/models/convnext/modeling_tf_convnext.py
+++ b/src/transformers/models/convnext/modeling_tf_convnext.py
@@ -235,8 +235,9 @@ class TFConvNextEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.stages = []
- drop_path_rates = [x for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
- cur = 0
+ drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
+ drop_path_rates = tf.split(drop_path_rates, config.depths)
+ drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
prev_chs = config.hidden_sizes[0]
for i in range(config.num_stages):
out_chs = config.hidden_sizes[i]
@@ -246,11 +247,10 @@ def __init__(self, config, **kwargs):
out_channels=out_chs,
stride=2 if i > 0 else 1,
depth=config.depths[i],
- drop_path_rates=drop_path_rates[cur],
+ drop_path_rates=drop_path_rates[i],
name=f"stages.{i}",
)
self.stages.append(stage)
- cur += config.depths[i]
prev_chs = out_chs
def call(self, hidden_states, output_hidden_states=False, return_dict=True):
@@ -537,7 +537,7 @@ def call(
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224")
- >>> model = TFViTForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
diff --git a/src/transformers/models/cpm/__init__.py b/src/transformers/models/cpm/__init__.py
index 06f8a912617ec2..01f45c436b11be 100644
--- a/src/transformers/models/cpm/__init__.py
+++ b/src/transformers/models/cpm/__init__.py
@@ -18,23 +18,43 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_tokenizers_available
_import_structure = {}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_cpm"] = ["CpmTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_cpm_fast"] = ["CpmTokenizerFast"]
if TYPE_CHECKING:
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_cpm import CpmTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_cpm_fast import CpmTokenizerFast
else:
diff --git a/src/transformers/models/ctrl/__init__.py b/src/transformers/models/ctrl/__init__.py
index efd342331fe8f5..53200a1b031979 100644
--- a/src/transformers/models/ctrl/__init__.py
+++ b/src/transformers/models/ctrl/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_ctrl": ["CTRLTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_ctrl"] = [
"CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
"CTRLForSequenceClassification",
@@ -35,7 +40,12 @@
"CTRLPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_ctrl"] = [
"TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFCTRLForSequenceClassification",
@@ -49,7 +59,12 @@
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .tokenization_ctrl import CTRLTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_ctrl import (
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
CTRLForSequenceClassification,
@@ -58,7 +73,12 @@
CTRLPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_ctrl import (
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCTRLForSequenceClassification,
diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py
index 291e12002fde3e..cec2d0d345b235 100644
--- a/src/transformers/models/ctrl/modeling_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_ctrl.py
@@ -784,7 +784,7 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py
index 7fadc65cfff44e..cdbed79135101a 100644
--- a/src/transformers/models/ctrl/modeling_tf_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -807,7 +807,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
diff --git a/src/transformers/models/ctrl/tokenization_ctrl.py b/src/transformers/models/ctrl/tokenization_ctrl.py
index c44b1d329f7e91..f8524bdf1f54ac 100644
--- a/src/transformers/models/ctrl/tokenization_ctrl.py
+++ b/src/transformers/models/ctrl/tokenization_ctrl.py
@@ -236,7 +236,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
diff --git a/src/transformers/models/cvt/__init__.py b/src/transformers/models/cvt/__init__.py
new file mode 100644
index 00000000000000..36a6f69824eff6
--- /dev/null
+++ b/src/transformers/models/cvt/__init__.py
@@ -0,0 +1,59 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {"configuration_cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"]}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_cvt"] = [
+ "CVT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "CvtForImageClassification",
+ "CvtModel",
+ "CvtPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_cvt import CVT_PRETRAINED_CONFIG_ARCHIVE_MAP, CvtConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_cvt import (
+ CVT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ CvtForImageClassification,
+ CvtModel,
+ CvtPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/cvt/configuration_cvt.py b/src/transformers/models/cvt/configuration_cvt.py
new file mode 100644
index 00000000000000..e1e633e73b57b1
--- /dev/null
+++ b/src/transformers/models/cvt/configuration_cvt.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" CvT model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+CVT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "microsoft/cvt-13": "https://huggingface.co/microsoft/cvt-13/resolve/main/config.json",
+ # See all Cvt models at https://huggingface.co/models?filter=cvt
+}
+
+
+class CvtConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`CvtModel`]. It is used to instantiate a CvT model
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the CvT
+ [microsoft/cvt-13](https://huggingface.co/microsoft/cvt-13) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3]`):
+ The kernel size of each encoder's patch embedding.
+ patch_stride (`List[int]`, *optional*, defaults to `[4, 2, 2]`):
+ The stride size of each encoder's patch embedding.
+ patch_padding (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
+ The padding size of each encoder's patch embedding.
+ embed_dim (`List[int]`, *optional*, defaults to `[64, 192, 384]`):
+ Dimension of each of the encoder blocks.
+ num_heads (`List[int]`, *optional*, defaults to `[1, 3, 6]`):
+ Number of attention heads for each attention layer in each block of the Transformer encoder.
+ depth (`List[int]`, *optional*, defaults to `[1, 2, 10]`):
+ The number of layers in each encoder block.
+ mlp_ratios (`List[float]`, *optional*, defaults to `[4.0, 4.0, 4.0, 4.0]`):
+ Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
+ encoder blocks.
+ attention_drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`):
+ The dropout ratio for the attention probabilities.
+ drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`):
+ The dropout ratio for the patch embeddings probabilities.
+ drop_path_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.1]`):
+ The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
+ qkv_bias (`List[bool]`, *optional*, defaults to `[True, True, True]`):
+ The bias bool for query, key and value in attentions
+ cls_token (`List[bool]`, *optional*, defaults to `[False, False, True]`):
+ Whether or not to add a classification token to the output of each of the last 3 stages.
+ qkv_projection_method (`List[string]`, *optional*, defaults to ["dw_bn", "dw_bn", "dw_bn"]`):
+ The projection method for query, key and value Default is depth-wise convolutions with batch norm. For
+ Linear projection use "avg".
+ kernel_qkv (`List[int]`, *optional*, defaults to `[3, 3, 3]`):
+ The kernel size for query, key and value in attention layer
+ padding_kv (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
+ The padding size for key and value in attention layer
+ stride_kv (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
+ The stride size for key and value in attention layer
+ padding_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
+ The padding size for query in attention layer
+ stride_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
+ The stride size for query in attention layer
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the layer normalization layers.
+
+ Example:
+
+ ```python
+ >>> from transformers import CvtModel, CvtConfig
+
+ >>> # Initializing a Cvt msft/cvt style configuration
+ >>> configuration = CvtConfig()
+
+ >>> # Initializing a model from the msft/cvt style configuration
+ >>> model = CvtModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "cvt"
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_sizes=[7, 3, 3],
+ patch_stride=[4, 2, 2],
+ patch_padding=[2, 1, 1],
+ embed_dim=[64, 192, 384],
+ num_heads=[1, 3, 6],
+ depth=[1, 2, 10],
+ mlp_ratio=[4.0, 4.0, 4.0],
+ attention_drop_rate=[0.0, 0.0, 0.0],
+ drop_rate=[0.0, 0.0, 0.0],
+ drop_path_rate=[0.0, 0.0, 0.1],
+ qkv_bias=[True, True, True],
+ cls_token=[False, False, True],
+ qkv_projection_method=["dw_bn", "dw_bn", "dw_bn"],
+ kernel_qkv=[3, 3, 3],
+ padding_kv=[1, 1, 1],
+ stride_kv=[2, 2, 2],
+ padding_q=[1, 1, 1],
+ stride_q=[1, 1, 1],
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.num_channels = num_channels
+ self.patch_sizes = patch_sizes
+ self.patch_stride = patch_stride
+ self.patch_padding = patch_padding
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.depth = depth
+ self.mlp_ratio = mlp_ratio
+ self.attention_drop_rate = attention_drop_rate
+ self.drop_rate = drop_rate
+ self.drop_path_rate = drop_path_rate
+ self.qkv_bias = qkv_bias
+ self.cls_token = cls_token
+ self.qkv_projection_method = qkv_projection_method
+ self.kernel_qkv = kernel_qkv
+ self.padding_kv = padding_kv
+ self.stride_kv = stride_kv
+ self.padding_q = padding_q
+ self.stride_q = stride_q
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
diff --git a/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 00000000000000..a33487c9e62a3c
--- /dev/null
+++ b/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,362 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert CvT checkpoints from the original repository.
+
+URL: https://github.com/microsoft/CvT"""
+
+
+import argparse
+import json
+from collections import OrderedDict
+
+import torch
+
+from huggingface_hub import cached_download, hf_hub_url
+from transformers import AutoFeatureExtractor, CvtConfig, CvtForImageClassification
+
+
+def embeddings(idx):
+ """
+ The function helps in renaming embedding layer weights.
+
+ Args:
+ idx: stage number in original model
+ """
+ embed = []
+ embed.append(
+ (
+ f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.weight",
+ f"stage{idx}.patch_embed.proj.weight",
+ )
+ )
+ embed.append(
+ (
+ f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.bias",
+ f"stage{idx}.patch_embed.proj.bias",
+ )
+ )
+ embed.append(
+ (
+ f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.weight",
+ f"stage{idx}.patch_embed.norm.weight",
+ )
+ )
+ embed.append(
+ (
+ f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.bias",
+ f"stage{idx}.patch_embed.norm.bias",
+ )
+ )
+ return embed
+
+
+def attention(idx, cnt):
+ """
+ The function helps in renaming attention block layers weights.
+
+ Args:
+ idx: stage number in original model
+ cnt: count of blocks in each stage
+ """
+ attention_weights = []
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.convolution.weight",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.conv.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.weight",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.bias",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.bias",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_mean",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_mean",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_var",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_var",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.num_batches_tracked",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.num_batches_tracked",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.convolution.weight",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.conv.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.weight",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.bias",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.bias",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_mean",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_mean",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_var",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_var",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.num_batches_tracked",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.num_batches_tracked",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.convolution.weight",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.conv.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.weight",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.bias",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.bias",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_mean",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_mean",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_var",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_var",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.num_batches_tracked",
+ f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.num_batches_tracked",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.weight",
+ f"stage{idx}.blocks.{cnt}.attn.proj_q.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.bias",
+ f"stage{idx}.blocks.{cnt}.attn.proj_q.bias",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.weight",
+ f"stage{idx}.blocks.{cnt}.attn.proj_k.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.bias",
+ f"stage{idx}.blocks.{cnt}.attn.proj_k.bias",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.weight",
+ f"stage{idx}.blocks.{cnt}.attn.proj_v.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.bias",
+ f"stage{idx}.blocks.{cnt}.attn.proj_v.bias",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.weight",
+ f"stage{idx}.blocks.{cnt}.attn.proj.weight",
+ )
+ )
+ attention_weights.append(
+ (
+ f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.bias",
+ f"stage{idx}.blocks.{cnt}.attn.proj.bias",
+ )
+ )
+ attention_weights.append(
+ (f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc1.weight")
+ )
+ attention_weights.append(
+ (f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc1.bias")
+ )
+ attention_weights.append(
+ (f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc2.weight")
+ )
+ attention_weights.append(
+ (f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc2.bias")
+ )
+ attention_weights.append(
+ (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.weight", f"stage{idx}.blocks.{cnt}.norm1.weight")
+ )
+ attention_weights.append(
+ (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.bias", f"stage{idx}.blocks.{cnt}.norm1.bias")
+ )
+ attention_weights.append(
+ (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.weight", f"stage{idx}.blocks.{cnt}.norm2.weight")
+ )
+ attention_weights.append(
+ (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.bias", f"stage{idx}.blocks.{cnt}.norm2.bias")
+ )
+ return attention_weights
+
+
+def cls_token(idx):
+ """
+ Function helps in renaming cls_token weights
+ """
+ token = []
+ token.append((f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token"))
+ return token
+
+
+def final():
+ """
+ Function helps in renaming final classification layer
+ """
+ head = []
+ head.append(("layernorm.weight", "norm.weight"))
+ head.append(("layernorm.bias", "norm.bias"))
+ head.append(("classifier.weight", "head.weight"))
+ head.append(("classifier.bias", "head.bias"))
+ return head
+
+
+def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_folder):
+ """
+ Fucntion to convert the microsoft cvt checkpoint to huggingface checkpoint
+ """
+ img_labels_file = "imagenet-1k-id2label.json"
+ num_labels = 1000
+
+ repo_id = "datasets/huggingface/label-files"
+ num_labels = num_labels
+ id2label = json.load(open(cached_download(hf_hub_url(repo_id, img_labels_file)), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+
+ id2label = id2label
+ label2id = {v: k for k, v in id2label.items()}
+
+ config = config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)
+
+ # For depth size 13 (13 = 1+2+10)
+ if cvt_model.rsplit("/", 1)[-1][4:6] == "13":
+ config.depth = [1, 2, 10]
+
+ # For depth size 21 (21 = 1+4+16)
+ elif cvt_model.rsplit("/", 1)[-1][4:6] == "21":
+ config.depth = [1, 4, 16]
+
+ # For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2 + 2 20)
+ else:
+ config.depth = [2, 2, 20]
+ config.num_heads = [3, 12, 16]
+ config.embed_dim = [192, 768, 1024]
+
+ model = CvtForImageClassification(config)
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")
+ feature_extractor.size = image_size
+ original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"))
+
+ huggingface_weights = OrderedDict()
+ list_of_state_dict = []
+
+ for idx in range(len(config.depth)):
+ if config.cls_token[idx]:
+ list_of_state_dict = list_of_state_dict + cls_token(idx)
+ list_of_state_dict = list_of_state_dict + embeddings(idx)
+ for cnt in range(config.depth[idx]):
+ list_of_state_dict = list_of_state_dict + attention(idx, cnt)
+
+ list_of_state_dict = list_of_state_dict + final()
+ for gg in list_of_state_dict:
+ print(gg)
+ for i in range(len(list_of_state_dict)):
+ huggingface_weights[list_of_state_dict[i][0]] = original_weights[list_of_state_dict[i][1]]
+
+ model.load_state_dict(huggingface_weights)
+ model.save_pretrained(pytorch_dump_folder)
+ feature_extractor.save_pretrained(pytorch_dump_folder)
+
+
+# Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--cvt_model",
+ default="cvt-w24",
+ type=str,
+ help="Name of the cvt model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--image_size",
+ default=384,
+ type=int,
+ help="Input Image Size",
+ )
+ parser.add_argument(
+ "--cvt_file_name",
+ default="cvtmodels\CvT-w24-384x384-IN-22k.pth",
+ type=str,
+ help="Input Image Size",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ )
+
+ args = parser.parse_args()
+ convert_cvt_checkpoint(args.cvt_model, args.image_size, args.cvt_file_name, args.pytorch_dump_folder_path)
diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py
new file mode 100644
index 00000000000000..ca6d3bd0b31411
--- /dev/null
+++ b/src/transformers/models/cvt/modeling_cvt.py
@@ -0,0 +1,723 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch CvT model."""
+
+
+import collections.abc
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput
+from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import logging
+from .configuration_cvt import CvtConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "CvtConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/cvt-13"
+_EXPECTED_OUTPUT_SHAPE = [1, 384, 14, 14]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "microsoft/cvt-13"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+CVT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/cvt-13",
+ "microsoft/cvt-13-384",
+ "microsoft/cvt-13-384-22k",
+ "microsoft/cvt-21",
+ "microsoft/cvt-21-384",
+ "microsoft/cvt-21-384-22k",
+ # See all Cvt models at https://huggingface.co/models?filter=cvt
+]
+
+
+@dataclass
+class BaseModelOutputWithCLSToken(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`):
+ Classification token at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ cls_token_value: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+# Copied from transformers.models.convnext.modeling_convnext.drop_path
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
+ DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
+ Connect' is a different form of dropout in a separate paper... See discussion:
+ https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
+ argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath
+class CvtDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class CvtEmbeddings(nn.Module):
+ """
+ Construct the CvT embeddings.
+ """
+
+ def __init__(self, patch_size, num_channels, embed_dim, stride, padding, dropout_rate):
+ super().__init__()
+ self.convolution_embeddings = CvtConvEmbeddings(
+ patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim, stride=stride, padding=padding
+ )
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, pixel_values):
+ hidden_state = self.convolution_embeddings(pixel_values)
+ hidden_state = self.dropout(hidden_state)
+ return hidden_state
+
+
+class CvtConvEmbeddings(nn.Module):
+ """
+ Image to Conv Embedding.
+ """
+
+ def __init__(self, patch_size, num_channels, embed_dim, stride, padding):
+ super().__init__()
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ self.patch_size = patch_size
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
+ self.normalization = nn.LayerNorm(embed_dim)
+
+ def forward(self, pixel_values):
+ pixel_values = self.projection(pixel_values)
+ batch_size, num_channels, height, width = pixel_values.shape
+ hidden_size = height * width
+ # rearrange "b c h w -> b (h w) c"
+ pixel_values = pixel_values.view(batch_size, num_channels, hidden_size).permute(0, 2, 1)
+ if self.normalization:
+ pixel_values = self.normalization(pixel_values)
+ # rearrange "b (h w) c" -> b c h w"
+ pixel_values = pixel_values.permute(0, 2, 1).view(batch_size, num_channels, height, width)
+ return pixel_values
+
+
+class CvtSelfAttentionConvProjection(nn.Module):
+ def __init__(self, embed_dim, kernel_size, padding, stride):
+ super().__init__()
+ self.convolution = nn.Conv2d(
+ embed_dim,
+ embed_dim,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ bias=False,
+ groups=embed_dim,
+ )
+ self.normalization = nn.BatchNorm2d(embed_dim)
+
+ def forward(self, hidden_state):
+ hidden_state = self.convolution(hidden_state)
+ hidden_state = self.normalization(hidden_state)
+ return hidden_state
+
+
+class CvtSelfAttentionLinearProjection(nn.Module):
+ def forward(self, hidden_state):
+ batch_size, num_channels, height, width = hidden_state.shape
+ hidden_size = height * width
+ # rearrange " b c h w -> b (h w) c"
+ hidden_state = hidden_state.view(batch_size, num_channels, hidden_size).permute(0, 2, 1)
+ return hidden_state
+
+
+class CvtSelfAttentionProjection(nn.Module):
+ def __init__(self, embed_dim, kernel_size, padding, stride, projection_method="dw_bn"):
+ super().__init__()
+ if projection_method == "dw_bn":
+ self.convolution_projection = CvtSelfAttentionConvProjection(embed_dim, kernel_size, padding, stride)
+ self.linear_projection = CvtSelfAttentionLinearProjection()
+
+ def forward(self, hidden_state):
+ hidden_state = self.convolution_projection(hidden_state)
+ hidden_state = self.linear_projection(hidden_state)
+ return hidden_state
+
+
+class CvtSelfAttention(nn.Module):
+ def __init__(
+ self,
+ num_heads,
+ embed_dim,
+ kernel_size,
+ padding_q,
+ padding_kv,
+ stride_q,
+ stride_kv,
+ qkv_projection_method,
+ qkv_bias,
+ attention_drop_rate,
+ with_cls_token=True,
+ **kwargs
+ ):
+ super().__init__()
+ self.scale = embed_dim**-0.5
+ self.with_cls_token = with_cls_token
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+
+ self.convolution_projection_query = CvtSelfAttentionProjection(
+ embed_dim,
+ kernel_size,
+ padding_q,
+ stride_q,
+ projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
+ )
+ self.convolution_projection_key = CvtSelfAttentionProjection(
+ embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
+ )
+ self.convolution_projection_value = CvtSelfAttentionProjection(
+ embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
+ )
+
+ self.projection_query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+ self.projection_key = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+ self.projection_value = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+
+ self.dropout = nn.Dropout(attention_drop_rate)
+
+ def rearrange_for_multi_head_attention(self, hidden_state):
+ batch_size, hidden_size, _ = hidden_state.shape
+ head_dim = self.embed_dim // self.num_heads
+ # rearrange 'b t (h d) -> b h t d'
+ return hidden_state.view(batch_size, hidden_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
+
+ def forward(self, hidden_state, height, width):
+ if self.with_cls_token:
+ cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
+ batch_size, hidden_size, num_channels = hidden_state.shape
+ # rearrange "b (h w) c -> b c h w"
+ hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
+
+ key = self.convolution_projection_key(hidden_state)
+ query = self.convolution_projection_query(hidden_state)
+ value = self.convolution_projection_value(hidden_state)
+
+ if self.with_cls_token:
+ query = torch.cat((cls_token, query), dim=1)
+ key = torch.cat((cls_token, key), dim=1)
+ value = torch.cat((cls_token, value), dim=1)
+
+ head_dim = self.embed_dim // self.num_heads
+
+ query = self.rearrange_for_multi_head_attention(self.projection_query(query))
+ key = self.rearrange_for_multi_head_attention(self.projection_key(key))
+ value = self.rearrange_for_multi_head_attention(self.projection_value(value))
+
+ attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale
+ attention_probs = torch.nn.functional.softmax(attention_score, dim=-1)
+ attention_probs = self.dropout(attention_probs)
+
+ context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value])
+ # rearrange"b h t d -> b t (h d)"
+ _, _, hidden_size, _ = context.shape
+ context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, hidden_size, self.num_heads * head_dim)
+ return context
+
+
+class CvtSelfOutput(nn.Module):
+ """
+ The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, embed_dim, drop_rate):
+ super().__init__()
+ self.dense = nn.Linear(embed_dim, embed_dim)
+ self.dropout = nn.Dropout(drop_rate)
+
+ def forward(self, hidden_state, input_tensor):
+ hidden_state = self.dense(hidden_state)
+ hidden_state = self.dropout(hidden_state)
+ return hidden_state
+
+
+class CvtAttention(nn.Module):
+ def __init__(
+ self,
+ num_heads,
+ embed_dim,
+ kernel_size,
+ padding_q,
+ padding_kv,
+ stride_q,
+ stride_kv,
+ qkv_projection_method,
+ qkv_bias,
+ attention_drop_rate,
+ drop_rate,
+ with_cls_token=True,
+ ):
+ super().__init__()
+ self.attention = CvtSelfAttention(
+ num_heads,
+ embed_dim,
+ kernel_size,
+ padding_q,
+ padding_kv,
+ stride_q,
+ stride_kv,
+ qkv_projection_method,
+ qkv_bias,
+ attention_drop_rate,
+ with_cls_token,
+ )
+ self.output = CvtSelfOutput(embed_dim, drop_rate)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_state, height, width):
+ self_output = self.attention(hidden_state, height, width)
+ attention_output = self.output(self_output, hidden_state)
+ return attention_output
+
+
+class CvtIntermediate(nn.Module):
+ def __init__(self, embed_dim, mlp_ratio):
+ super().__init__()
+ self.dense = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))
+ self.activation = nn.GELU()
+
+ def forward(self, hidden_state):
+ hidden_state = self.dense(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
+
+
+class CvtOutput(nn.Module):
+ def __init__(self, embed_dim, mlp_ratio, drop_rate):
+ super().__init__()
+ self.dense = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
+ self.dropout = nn.Dropout(drop_rate)
+
+ def forward(self, hidden_state, input_tensor):
+ hidden_state = self.dense(hidden_state)
+ hidden_state = self.dropout(hidden_state)
+ hidden_state = hidden_state + input_tensor
+ return hidden_state
+
+
+class CvtLayer(nn.Module):
+ """
+ CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps).
+ """
+
+ def __init__(
+ self,
+ num_heads,
+ embed_dim,
+ kernel_size,
+ padding_q,
+ padding_kv,
+ stride_q,
+ stride_kv,
+ qkv_projection_method,
+ qkv_bias,
+ attention_drop_rate,
+ drop_rate,
+ mlp_ratio,
+ drop_path_rate,
+ with_cls_token=True,
+ ):
+ super().__init__()
+ self.attention = CvtAttention(
+ num_heads,
+ embed_dim,
+ kernel_size,
+ padding_q,
+ padding_kv,
+ stride_q,
+ stride_kv,
+ qkv_projection_method,
+ qkv_bias,
+ attention_drop_rate,
+ drop_rate,
+ with_cls_token,
+ )
+
+ self.intermediate = CvtIntermediate(embed_dim, mlp_ratio)
+ self.output = CvtOutput(embed_dim, mlp_ratio, drop_rate)
+ self.drop_path = CvtDropPath(drop_prob=drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+ self.layernorm_before = nn.LayerNorm(embed_dim)
+ self.layernorm_after = nn.LayerNorm(embed_dim)
+
+ def forward(self, hidden_state, height, width):
+ self_attention_output = self.attention(
+ self.layernorm_before(hidden_state), # in Cvt, layernorm is applied before self-attention
+ height,
+ width,
+ )
+ attention_output = self_attention_output
+ attention_output = self.drop_path(attention_output)
+
+ # first residual connection
+ hidden_state = attention_output + hidden_state
+
+ # in Cvt, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_state)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_state)
+ layer_output = self.drop_path(layer_output)
+ return layer_output
+
+
+class CvtStage(nn.Module):
+ def __init__(self, config, stage):
+ super().__init__()
+ self.config = config
+ self.stage = stage
+ if self.config.cls_token[self.stage]:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.embed_dim[-1]))
+
+ self.embedding = CvtEmbeddings(
+ patch_size=config.patch_sizes[self.stage],
+ stride=config.patch_stride[self.stage],
+ num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
+ embed_dim=config.embed_dim[self.stage],
+ padding=config.patch_padding[self.stage],
+ dropout_rate=config.drop_rate[self.stage],
+ )
+
+ drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage])]
+
+ self.layers = nn.Sequential(
+ *[
+ CvtLayer(
+ num_heads=config.num_heads[self.stage],
+ embed_dim=config.embed_dim[self.stage],
+ kernel_size=config.kernel_qkv[self.stage],
+ padding_q=config.padding_q[self.stage],
+ padding_kv=config.padding_kv[self.stage],
+ stride_kv=config.stride_kv[self.stage],
+ stride_q=config.stride_q[self.stage],
+ qkv_projection_method=config.qkv_projection_method[self.stage],
+ qkv_bias=config.qkv_bias[self.stage],
+ attention_drop_rate=config.attention_drop_rate[self.stage],
+ drop_rate=config.drop_rate[self.stage],
+ drop_path_rate=drop_path_rates[self.stage],
+ mlp_ratio=config.mlp_ratio[self.stage],
+ with_cls_token=config.cls_token[self.stage],
+ )
+ for _ in range(config.depth[self.stage])
+ ]
+ )
+
+ def forward(self, hidden_state):
+ cls_token = None
+ hidden_state = self.embedding(hidden_state)
+ batch_size, num_channels, height, width = hidden_state.shape
+ # rearrange b c h w -> b (h w) c"
+ hidden_state = hidden_state.view(batch_size, num_channels, height * width).permute(0, 2, 1)
+ if self.config.cls_token[self.stage]:
+ cls_token = self.cls_token.expand(batch_size, -1, -1)
+ hidden_state = torch.cat((cls_token, hidden_state), dim=1)
+
+ for layer in self.layers:
+ layer_outputs = layer(hidden_state, height, width)
+ hidden_state = layer_outputs
+
+ if self.config.cls_token[self.stage]:
+ cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
+ hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
+ return hidden_state, cls_token
+
+
+class CvtEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.stages = nn.ModuleList([])
+ for stage_idx in range(len(config.depth)):
+ self.stages.append(CvtStage(config, stage_idx))
+
+ def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
+ all_hidden_states = () if output_hidden_states else None
+ hidden_state = pixel_values
+
+ cls_token = None
+ for _, (stage_module) in enumerate(self.stages):
+ hidden_state, cls_token = stage_module(hidden_state)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_state,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithCLSToken(
+ last_hidden_state=hidden_state,
+ cls_token_value=cls_token,
+ hidden_states=all_hidden_states,
+ )
+
+
+class CvtPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = CvtConfig
+ base_model_prefix = "cvt"
+ main_input_name = "pixel_values"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+CVT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CVT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`CvtFeatureExtractor`]. See
+ [`CvtFeatureExtractor.__call__`] for details.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.",
+ CVT_START_DOCSTRING,
+)
+class CvtModel(CvtPreTrainedModel):
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+ self.encoder = CvtEncoder(config)
+ self.post_init()
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithCLSToken,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ encoder_outputs = self.encoder(
+ pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return BaseModelOutputWithCLSToken(
+ last_hidden_state=sequence_output,
+ cls_token_value=encoder_outputs.cls_token_value,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.
+ """,
+ CVT_START_DOCSTRING,
+)
+class CvtForImageClassification(CvtPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.cvt = CvtModel(config, add_pooling_layer=False)
+ self.layernorm = nn.LayerNorm(config.embed_dim[-1])
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.embed_dim[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=ImageClassifierOutputWithNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values=None,
+ labels=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ outputs = self.cvt(
+ pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ cls_token = outputs[1]
+ if self.config.cls_token[-1]:
+ sequence_output = self.layernorm(cls_token)
+ else:
+ batch_size, num_channels, height, width = sequence_output.shape
+ # rearrange "b c h w -> b (h w) c"
+ sequence_output = sequence_output.view(batch_size, num_channels, height * width).permute(0, 2, 1)
+ sequence_output = self.layernorm(sequence_output)
+
+ sequence_output_mean = sequence_output.mean(dim=1)
+ logits = self.classifier(sequence_output_mean)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.config.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.config.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
diff --git a/src/transformers/models/data2vec/__init__.py b/src/transformers/models/data2vec/__init__.py
index a1296fd334e365..794124575e1316 100644
--- a/src/transformers/models/data2vec/__init__.py
+++ b/src/transformers/models/data2vec/__init__.py
@@ -18,14 +18,11 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
- "configuration_data2vec_audio": [
- "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "Data2VecAudioConfig",
- ],
+ "configuration_data2vec_audio": ["DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP", "Data2VecAudioConfig"],
"configuration_data2vec_text": [
"DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Data2VecTextConfig",
@@ -38,7 +35,12 @@
],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_data2vec_audio"] = [
"DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST",
"Data2VecAudioForAudioFrameClassification",
@@ -68,6 +70,14 @@
"Data2VecVisionPreTrainedModel",
]
+if is_tf_available():
+ _import_structure["modeling_tf_data2vec_vision"] = [
+ "TFData2VecVisionForImageClassification",
+ "TFData2VecVisionForSemanticSegmentation",
+ "TFData2VecVisionModel",
+ "TFData2VecVisionPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
from .configuration_data2vec_audio import DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP, Data2VecAudioConfig
from .configuration_data2vec_text import (
@@ -81,7 +91,12 @@
Data2VecVisionOnnxConfig,
)
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_data2vec_audio import (
DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST,
Data2VecAudioForAudioFrameClassification,
@@ -110,6 +125,13 @@
Data2VecVisionModel,
Data2VecVisionPreTrainedModel,
)
+ if is_tf_available():
+ from .modeling_tf_data2vec_vision import (
+ TFData2VecVisionForImageClassification,
+ TFData2VecVisionForSemanticSegmentation,
+ TFData2VecVisionModel,
+ TFData2VecVisionPreTrainedModel,
+ )
else:
import sys
diff --git a/src/transformers/models/data2vec/configuration_data2vec_audio.py b/src/transformers/models/data2vec/configuration_data2vec_audio.py
index 71d455702e6396..cc32f2cc698972 100644
--- a/src/transformers/models/data2vec/configuration_data2vec_audio.py
+++ b/src/transformers/models/data2vec/configuration_data2vec_audio.py
@@ -71,13 +71,13 @@ class Data2VecAudioConfig(PretrainedConfig):
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
@@ -124,13 +124,13 @@ class Data2VecAudioConfig(PretrainedConfig):
instance of [`Data2VecAudioForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
- tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
- tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
- tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
@@ -245,10 +245,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/data2vec/configuration_data2vec_vision.py b/src/transformers/models/data2vec/configuration_data2vec_vision.py
index 5508f4d9e7e779..a7dd85b817348a 100644
--- a/src/transformers/models/data2vec/configuration_data2vec_vision.py
+++ b/src/transformers/models/data2vec/configuration_data2vec_vision.py
@@ -26,7 +26,9 @@
logger = logging.get_logger(__name__)
DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/data2vec-vision-base-ft": "https://huggingface.co/facebook/data2vec-vision-base-ft/resolve/main/config.json",
+ "facebook/data2vec-vision-base-ft": (
+ "https://huggingface.co/facebook/data2vec-vision-base-ft/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py
index e8a703de91f367..01c2d8cab27894 100644
--- a/src/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py
@@ -66,7 +66,8 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
if hf_shape != value.shape:
raise ValueError(
- f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
)
if weight_type == "weight":
diff --git a/src/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
index 8659e36d9f4838..9a38b3ae0bd1a3 100644
--- a/src/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
@@ -98,13 +98,22 @@ def convert_data2vec_checkpoint_to_pytorch(
self_attn: BertSelfAttention = layer.attention.self
assert data2vec_layer.self_attn.k_proj.weight.data.shape == torch.Size(
(config.hidden_size, config.hidden_size)
- ), f"Shape for data2vec_layer.self_attn.k_proj.weight.data should be {torch.Size((config.hidden_size, config.hidden_size))}"
+ ), (
+ "Shape for data2vec_layer.self_attn.k_proj.weight.data should be"
+ f" {torch.Size((config.hidden_size, config.hidden_size))}"
+ )
assert data2vec_layer.self_attn.q_proj.weight.data.shape == torch.Size(
(config.hidden_size, config.hidden_size)
- ), f"Shape for data2vec_layer.self_attn.q_proj.weight.data should be {torch.Size((config.hidden_size, config.hidden_size))}"
+ ), (
+ "Shape for data2vec_layer.self_attn.q_proj.weight.data should be"
+ f" {torch.Size((config.hidden_size, config.hidden_size))}"
+ )
assert data2vec_layer.self_attn.v_proj.weight.data.shape == torch.Size(
(config.hidden_size, config.hidden_size)
- ), f"Shape for data2vec_layer.self_attn.v_proj.weight.data should be {torch.Size((config.hidden_size, config.hidden_size))}"
+ ), (
+ "Shape for data2vec_layer.self_attn.v_proj.weight.data should be"
+ f" {torch.Size((config.hidden_size, config.hidden_size))}"
+ )
self_attn.query.weight.data = data2vec_layer.self_attn.q_proj.weight
self_attn.query.bias.data = data2vec_layer.self_attn.q_proj.bias
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index 3f255248c1bae2..1ee847f4ea1a70 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -16,7 +16,6 @@
import math
import warnings
-from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
@@ -27,16 +26,17 @@
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
-from ...utils import (
- ModelOutput,
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
-)
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_data2vec_audio import Data2VecAudioConfig
@@ -81,69 +81,6 @@
]
-@dataclass
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput with Wav2Vec2->Data2VecAudio
-class Data2VecAudioBaseModelOutput(ModelOutput):
- """
- Output type of [`Data2VecAudioBaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
-@dataclass
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.XVectorOutput with Wav2Vec2->Data2VecAudio
-class XVectorOutput(ModelOutput):
- """
- Output type of [`Data2VecAudioForXVector`].
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Classification loss.
- logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Classification hidden states before AMSoftmax.
- embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Utterance embeddings used for vector similarity-based retrieval.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- embeddings: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
@@ -498,7 +435,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -514,7 +452,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -535,7 +474,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -634,7 +574,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -970,7 +911,7 @@ def _mask_hidden_states(
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Data2VecAudioBaseModelOutput,
+ output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -983,7 +924,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Data2VecAudioBaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1020,7 +961,7 @@ def forward(
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return Data2VecAudioBaseModelOutput(
+ return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1294,13 +1235,15 @@ def __init__(self, config):
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
- "Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
+ "Audio frame classification does not support the use of Data2VecAudio adapters"
+ " (config.add_adapter=True)"
)
self.data2vec_audio = Data2VecAudioModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
self.init_weights()
@@ -1344,6 +1287,7 @@ def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1376,12 +1320,17 @@ def forward(
logits = self.classifier(hidden_states)
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
- loss=None,
+ loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py
index 9168281eb8447a..9c85d346174aaf 100644
--- a/src/transformers/models/data2vec/modeling_data2vec_text.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_text.py
@@ -182,7 +182,7 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@@ -426,7 +426,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -820,7 +821,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -908,21 +909,21 @@ def set_output_embeddings(self, new_embeddings):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- labels=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1069,19 +1070,19 @@ def set_output_embeddings(self, new_embeddings):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1183,17 +1184,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1282,17 +1283,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- token_type_ids=None,
- attention_mask=None,
- labels=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1380,17 +1381,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1484,18 +1485,18 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- start_positions=None,
- end_positions=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py
new file mode 100644
index 00000000000000..e7cc7d2449e75b
--- /dev/null
+++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py
@@ -0,0 +1,1452 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF 2.0 Data2Vec Vision model."""
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from transformers.tf_utils import shape_list, stable_softmax
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPooling,
+ TFSemanticSegmenterOutput,
+ TFSequenceClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_data2vec_vision import Data2VecVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "Data2VecVisionConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "BeitFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
+
+TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/data2vec-vision-base-ft1k",
+ # See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision
+]
+
+
+@dataclass
+class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling):
+ """
+ Class for outputs of [`TFData2VecVisionModel`].
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+ Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
+ *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
+ will be returned.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: tf.Tensor = None
+ pooler_output: tf.Tensor = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+
+
+class TFDropPath(tf.keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ References:
+ (1) github.com:rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path = drop_path
+
+ def call(self, x, training=None):
+ if training:
+ keep_prob = 1 - self.drop_path
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
+ """
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+
+ """
+
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.patch_embeddings = TFPatchEmbeddings(
+ config=config, image_size=config.image_size, patch_size=config.patch_size, name="patch_embeddings"
+ )
+ self.num_patches = self.patch_embeddings.num_patches
+ self.config = config
+
+ self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+
+ def build(self, input_shape: tf.TensorShape):
+ self.cls_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+ trainable=True,
+ name="cls_token",
+ )
+ if self.config.use_mask_token:
+ self.mask_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+ trainable=True,
+ name="mask_token",
+ )
+ else:
+ self.mask_token = None
+
+ if self.config.use_absolute_position_embeddings:
+ self.position_embeddings = self.add_weight(
+ shape=(1, self.num_patches + 1, self.config.hidden_size),
+ initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+ trainable=True,
+ name="position_embeddings",
+ )
+ else:
+ self.position_embeddings = None
+
+ super().build(input_shape)
+
+ def call(self, pixel_values: tf.Tensor, bool_masked_pos: Optional[tf.Tensor] = None) -> tf.Tensor:
+
+ embeddings = self.patch_embeddings(pixel_values)
+ batch_size, seq_len, projection_dim = shape_list(embeddings)
+
+ cls_tokens = tf.tile(self.cls_token, (batch_size, 1, 1))
+
+ if bool_masked_pos is not None:
+ mask_tokens = tf.broadcast_to(self.mask_token, (batch_size, seq_len, projection_dim))
+ # replace the masked visual tokens by mask_tokens
+ w = bool_masked_pos[..., None]
+ w = tf.cast(w, mask_tokens.dtype)
+ # since TF doesn't support eager tensor assignment
+ embeddings = embeddings * (1 - w) + mask_tokens * w
+
+ embeddings = tf.concat([cls_tokens, embeddings], axis=1)
+ if self.position_embeddings is not None:
+ embeddings = embeddings + self.position_embeddings
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+class TFPatchEmbeddings(tf.keras.layers.Layer):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig, image_size: int = 224, patch_size: int = 16, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+
+ image_size = (
+ config.image_size
+ if isinstance(config.image_size, collections.abc.Iterable)
+ else (config.image_size, config.image_size)
+ )
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.patch_shape = patch_shape
+ self.num_channels = config.num_channels
+ self.embed_dim = config.hidden_size
+
+ self.projection = tf.keras.layers.Conv2D(
+ filters=self.embed_dim,
+ kernel_size=self.patch_size,
+ strides=self.patch_size,
+ padding="valid",
+ data_format="channels_last",
+ kernel_initializer="glorot_uniform", # following torch.nn.Linear
+ bias_initializer="zeros",
+ name="projection",
+ )
+
+ def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
+ batch_size, num_channels, height, width = shape_list(pixel_values)
+ if getattr(height, "numpy", None) and getattr(width, "numpy", None):
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+
+ # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ projection = self.projection(pixel_values)
+
+ # Change the 2D spatial dimensions to a single temporal dimension.
+ # shape = (batch_size, num_patches, out_channels=embed_dim)
+ num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
+
+ return tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
+
+
+class TFData2VecVisionSelfAttention(tf.keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
+ super().__init__(**kwargs)
+
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+ f"of attention heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+ self.query = tf.keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+ )
+ self.key = tf.keras.layers.Dense(
+ units=self.all_head_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="key",
+ use_bias=False,
+ )
+ self.value = tf.keras.layers.Dense(
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+ )
+ self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+
+ if window_size:
+ self.relative_position_bias = TFData2VecVisionRelativePositionBias(
+ config, window_size=window_size, name="relative_position_bias"
+ )
+ else:
+ self.relative_position_bias = None
+
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
+ training: bool = False,
+ ) -> Tuple[tf.Tensor]:
+ batch_size = shape_list(hidden_states)[0]
+ mixed_query_layer = self.query(inputs=hidden_states)
+ mixed_key_layer = self.key(inputs=hidden_states)
+ mixed_value_layer = self.value(inputs=hidden_states)
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # (batch size, num_heads, seq_len_q, seq_len_k)
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+ attention_scores = attention_scores / self.sqrt_att_head_size
+
+ # Add relative position bias if present.
+ if self.relative_position_bias is not None:
+ # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
+ # might complain about `Layer.call()` not being invoked properly. In this case this input
+ # i.e., 0.0 is not going to be used in any calculations so we're safe.
+ attention_scores = attention_scores + self.relative_position_bias(0.0)[None, ...]
+
+ # Add shared relative position bias if provided.
+ if relative_position_bias is not None:
+ attention_scores = attention_scores + relative_position_bias
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(inputs=attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = tf.multiply(attention_probs, head_mask)
+
+ attention_output = tf.matmul(attention_probs, value_layer)
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+ # (batch_size, seq_len_q, all_head_size)
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+
+ return outputs
+
+
+class TFData2VecVisionSelfOutput(tf.keras.layers.Layer):
+ """
+ The residual connection is defined in TFData2VecVisionLayer instead of here (as is the case with other models), due
+ to the layernorm applied before each block.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = tf.keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, gamma=None, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+ return hidden_states
+
+
+class TFData2VecVisionAttention(tf.keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
+ super().__init__(**kwargs)
+
+ self.attention = TFData2VecVisionSelfAttention(config, window_size=window_size, name="attention")
+ self.dense_output = TFData2VecVisionSelfOutput(config, name="output")
+
+ def prune_heads(self, heads):
+ raise NotImplementedError
+
+ def call(
+ self,
+ input_tensor: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
+ training: bool = False,
+ ) -> Tuple[tf.Tensor]:
+ self_outputs = self.attention(
+ hidden_states=input_tensor,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ relative_position_bias=relative_position_bias,
+ training=training,
+ )
+ attention_output = self.dense_output(
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
+ )
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->Data2VecVision
+class TFData2VecVisionIntermediate(tf.keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = tf.keras.layers.Dense(
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+class TFData2VecVisionOutput(tf.keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = tf.keras.layers.Dense(
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+ return hidden_states
+
+
+class TFData2VecVisionLayer(tf.keras.layers.Layer):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(
+ self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.attention = TFData2VecVisionAttention(config, window_size=window_size, name="attention")
+ self.intermediate = TFData2VecVisionIntermediate(config, name="intermediate")
+ self.data2vec_output = TFData2VecVisionOutput(config, name="output")
+
+ self.layernorm_before = tf.keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="layernorm_before"
+ )
+ self.layernorm_after = tf.keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="layernorm_after"
+ )
+ # Using `layers.Activation` instead of `tf.identity` to better control `training`
+ # behaviour.
+ self.drop_path = (
+ TFDropPath(drop_path_rate, name="drop_path")
+ if drop_path_rate > 0.0
+ else tf.keras.layers.Activation("linear", name="drop_path")
+ )
+ self.init_values = config.layer_scale_init_value
+
+ def build(self, input_shape: tf.TensorShape):
+ if self.init_values > 0:
+ self.lambda_1 = self.add_weight(
+ shape=(self.config.hidden_size),
+ initializer="ones",
+ trainable=True,
+ name="lambda_1",
+ )
+ self.lambda_2 = self.add_weight(
+ shape=(self.config.hidden_size),
+ initializer="ones",
+ trainable=True,
+ name="lambda_2",
+ )
+ self.lambda_1.assign(self.init_values * tf.ones((self.config.hidden_size)))
+ self.lambda_2.assign(self.init_values * tf.ones((self.config.hidden_size)))
+ else:
+ self.lambda_1, self.lambda_2 = None, None
+
+ super().build(input_shape)
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: tf.Tensor,
+ output_attentions: bool,
+ relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
+ training: bool = False,
+ ) -> Tuple[tf.Tensor]:
+ self_attention_outputs = self.attention(
+ # in Data2VecVision, layernorm is applied before self-attention
+ input_tensor=self.layernorm_before(inputs=hidden_states),
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ relative_position_bias=relative_position_bias,
+ training=training,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # apply lambda_1 if present
+ if self.lambda_1 is not None:
+ attention_output = self.lambda_1 * attention_output
+
+ # first residual connection
+ hidden_states = self.drop_path(attention_output) + hidden_states
+
+ # in Data2VecVision, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+
+ layer_output = self.intermediate(layer_output)
+ layer_output = self.data2vec_output(layer_output)
+
+ if self.lambda_2 is not None:
+ layer_output = self.lambda_2 * layer_output
+
+ # second residual connection
+ layer_output = self.drop_path(layer_output) + hidden_states
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+# Taken and modified from here:
+# https://github.com/leondgarse/keras_cv_attention_models/blob/main/keras_cv_attention_models/beit/beit.py#L28
+class TFData2VecVisionRelativePositionBias(tf.keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, window_size: tuple, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.window_size = window_size
+ # +3 for cls_token_pos_len
+ # window_size can be something like (14, 14)
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+
+ self.relative_position_index = self.get_position_index()
+
+ def build(self, input_shape):
+ self.relative_position_bias_table = self.add_weight(
+ shape=(self.num_relative_distance, self.config.num_attention_heads),
+ initializer="zeros",
+ trainable=True,
+ name="relative_position_bias_table",
+ ) # [2*Wh-1 * 2*Ww-1, nH]
+ # cls to token & token 2 cls & cls to cls
+
+ super().build(input_shape)
+
+ def get_position_index(self):
+ # get pair-wise relative position index for each token inside the window
+ xx, yy = tf.meshgrid(range(self.window_size[0]), range(self.window_size[1]))
+ coords = tf.stack([yy, xx], axis=0) # [2, Wh, Ww]
+ coords_flatten = tf.reshape(coords, [2, -1]) # [2, Wh*Ww]
+
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Wh*Ww, Wh*Ww]
+ relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0]) # [Wh*Ww, Wh*Ww, 2]
+
+ xx = (relative_coords[:, :, 0] + self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
+ yy = relative_coords[:, :, 1] + self.window_size[1] - 1
+ relative_coords = tf.stack([xx, yy], axis=-1)
+
+ relative_position_index = tf.reduce_sum(relative_coords, axis=-1) # [Wh*Ww, Wh*Ww]
+
+ top = tf.ones((1, relative_position_index.shape[1]), dtype=relative_position_index.dtype) * (
+ self.num_relative_distance - 3
+ )
+ left = tf.ones((relative_position_index.shape[0], 1), dtype=relative_position_index.dtype) * (
+ self.num_relative_distance - 2
+ )
+ corner = tf.ones((1, 1), dtype=relative_position_index.dtype) * (self.num_relative_distance - 1)
+
+ left_corner = tf.concat([corner, left], axis=0)
+ relative_position_index = tf.concat([top, relative_position_index], axis=0)
+ relative_position_index = tf.concat([left_corner, relative_position_index], axis=1) # [Wh*Ww + 1, Wh*Ww + 1]
+ return relative_position_index
+
+ def call(self, inputs=None) -> tf.Tensor:
+ relative_position_bias = tf.gather(self.relative_position_bias_table, self.relative_position_index, axis=0)
+ return tf.transpose(relative_position_bias, [2, 0, 1])
+
+
+class TFData2VecVisionEncoder(tf.keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ if config.use_shared_relative_position_bias:
+ self.relative_position_bias = TFData2VecVisionRelativePositionBias(
+ config, window_size=window_size, name="relative_position_bias"
+ )
+ else:
+ self.relative_position_bias = None
+
+ # stochastic depth decay rule
+ dpr = [x for x in tf.linspace(0.0, config.drop_path_rate, config.num_hidden_layers)]
+ self.layer = [
+ TFData2VecVisionLayer(
+ config,
+ window_size=window_size if config.use_relative_position_bias else None,
+ drop_path_rate=dpr[i],
+ name=f"layer_._{i}",
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, TFBaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
+ # might complain about `Layer.call()` not being invoked properly. In this case this input
+ # i.e., 0.0 is not going to be used in any calculations so we're safe.
+ relative_position_bias = (
+ self.relative_position_bias(0.0) if self.relative_position_bias is not None else None
+ )
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@keras_serializable
+class TFData2VecVisionMainLayer(tf.keras.layers.Layer):
+ config_class = Data2VecVisionConfig
+
+ def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = True, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.add_pooling_layer = add_pooling_layer
+
+ self.embeddings = TFData2VecVisionEmbeddings(config, name="embeddings")
+ self.encoder = TFData2VecVisionEncoder(
+ config, window_size=self.embeddings.patch_embeddings.patch_shape, name="encoder"
+ )
+ self.layernorm = (
+ tf.identity
+ if config.use_mean_pooling
+ else tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ )
+
+ # We are setting the `data_format` like so because from here on we will revert to the
+ # NCHW output format
+ self.pooler = TFData2VecVisionPooler(config, name="pooler") if add_pooling_layer else None
+
+ def get_input_embeddings(self) -> tf.keras.layers.Layer:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ raise NotImplementedError
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ bool_masked_pos: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[tuple, TFData2VecVisionModelOutputWithPooling]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ if head_mask is not None:
+ raise NotImplementedError
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ embedding_output = self.embeddings(pixel_values, bool_masked_pos, training=training)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return TFData2VecVisionModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class TFData2VecVisionPooler(tf.keras.layers.Layer):
+ def __init__(self, config: Data2VecVisionConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.layernorm = (
+ tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ if config.use_mean_pooling
+ else None
+ )
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ if self.layernorm is not None:
+ # Mean pool the final hidden states of the patch tokens
+ patch_tokens = hidden_states[:, 1:, :]
+ pooled_output = self.layernorm(tf.reduce_mean(patch_tokens, axis=1))
+ else:
+ # Pool by simply taking the final hidden state of the [CLS] token
+ pooled_output = hidden_states[:, 0]
+
+ return pooled_output
+
+
+class TFData2VecVisionPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Data2VecVisionConfig
+ base_model_prefix = "data2vec_vision"
+ main_input_name = "pixel_values"
+ _keys_to_ignore_on_load_unexpected = [r"relative_position_index"]
+
+ @property
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
+ """
+ Dummy inputs to build the network. Returns:
+ `Dict[str, tf.Tensor]`: The dummy inputs.
+ """
+ VISION_DUMMY_INPUTS = tf.random.uniform(
+ shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size),
+ dtype=tf.float32,
+ )
+ return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
+
+ @tf.function(
+ input_signature=[
+ {
+ "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ """
+ Method used for serving the model.
+
+ Args:
+ inputs (`Dict[str, tf.Tensor]`):
+ The input of the saved model as a dictionary of tensors.
+ """
+
+ return self.call(inputs)
+
+
+DATA2VEC_VISION_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.).
+
+ This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TF 2.0 models accepts two formats as inputs:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional arguments.
+
+ This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the
+ tensors in the first argument of the model call function: `model(inputs)`.
+
+
+
+ Args:
+ config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`BeitFeatureExtractor`]. See
+ [`BeitFeatureExtractor.__call__`] for details.
+
+ head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
+ in eager mode, in graph mode the value will always be set to True.
+
+ training (`bool`, *optional*, defaults to `False``):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.",
+ DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel):
+ def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.config = config
+
+ self.data2vec_vision = TFData2VecVisionMainLayer(
+ config, add_pooling_layer=add_pooling_layer, name="data2vec_vision"
+ )
+
+ def get_input_embeddings(self):
+ return self.data2vec_vision.get_input_embeddings()
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFData2VecVisionModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: Optional[TFModelInputType] = None,
+ bool_masked_pos: Optional[tf.Tensor] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[tuple, TFData2VecVisionModelOutputWithPooling]:
+
+ outputs = self.data2vec_vision(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+
+@add_start_docstrings(
+ """
+ Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
+ the final hidden states of the patch tokens) e.g. for ImageNet.
+ """,
+ DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=True, name="data2vec_vision")
+
+ # Classifier head
+ self.classifier = tf.keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ name="classifier",
+ )
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: Optional[TFModelInputType] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ training: Optional[bool] = False,
+ ) -> Union[TFSequenceClassifierOutput, tuple]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.data2vec_vision(
+ pixel_values=pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+ logits = self.classifier(pooled_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class TFData2VecVisionConvModule(tf.keras.layers.Layer):
+ """
+ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
+ layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(
+ self,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ padding: str = "valid",
+ bias: bool = False,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.conv = tf.keras.layers.Conv2D(
+ filters=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ use_bias=bias,
+ dilation_rate=dilation,
+ name="conv",
+ )
+ self.bn = tf.keras.layers.BatchNormalization(name="bn")
+ self.activation = tf.nn.relu
+
+ def call(self, input: tf.Tensor) -> tf.Tensor:
+ output = self.conv(input)
+ output = self.bn(output)
+ output = self.activation(output)
+ return output
+
+
+# Copied from:
+# https://gist.github.com/Rocketknight1/43abbe6e73f1008e6e459486e01e0ceb
+class TFAdaptiveAvgPool1D(tf.keras.layers.Layer):
+ def __init__(self, output_dim, mode="dense", **kwargs):
+ super().__init__(**kwargs)
+ self.output_dim = output_dim
+ self.mode = mode
+ self.map = None
+
+ def build(self, input_shape):
+ super().build(input_shape)
+ """We pre-compute the sparse matrix for the build() step once. The below code comes
+ from https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993#63603993."""
+
+ def get_kernels(ind, outd) -> List:
+ """Returns a List [(kernel_offset_start,kernel_length)] defining all the pooling kernels for a 1-D adaptive
+ pooling layer that takes an input of dimension `ind` and yields an output of dimension `outd`"""
+
+ def start_index(a, b, c):
+ return math.floor((float(a) * float(c)) / b)
+
+ def end_index(a, b, c):
+ return math.ceil((float(a + 1) * float(c)) / b)
+
+ results = []
+ for ow in range(outd):
+ start = start_index(ow, outd, ind)
+ end = end_index(ow, outd, ind)
+ sz = end - start
+ results.append((start, sz))
+ return results
+
+ in_dim = int(input_shape[-1])
+ kernels = get_kernels(in_dim, self.output_dim)
+ sparse_map = np.zeros((in_dim, self.output_dim), dtype=np.float32)
+ for i, kernel in enumerate(kernels):
+ sparse_map[kernel[0] : kernel[0] + kernel[1], i] = 1 / kernel[1]
+ if self.mode == "dense":
+ self.map = tf.constant(sparse_map)
+ else:
+ self.map = tf.sparse.from_dense(sparse_map)
+
+ def call(self, inputs):
+ if self.mode == "dense":
+ return inputs @ self.map
+ else:
+ input_dims = inputs.shape
+ input_matrix = tf.reshape(inputs, (-1, input_dims[-1]))
+ out = tf.sparse.sparse_dense_matmul(input_matrix, self.map)
+ return tf.reshape(out, input_dims[:-1].as_list() + [-1])
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"output_dim": self.output_dim, "mode": self.mode})
+ return config
+
+
+class TFAdaptiveAvgPool2D(tf.keras.layers.Layer):
+ def __init__(self, output_shape, mode="dense", **kwargs):
+ super().__init__(**kwargs)
+ self.mode = mode
+ self.h_pool = TFAdaptiveAvgPool1D(output_shape[0], mode=mode, name="h_pool")
+ self.w_pool = TFAdaptiveAvgPool1D(output_shape[1], mode=mode, name="w_pool")
+
+ def call(self, inputs):
+ # Rearrange from NHWC -> NCHW
+ inputs = tf.transpose(inputs, perm=[0, 3, 1, 2])
+ # Perform W-pooling
+ inputs = self.w_pool(inputs)
+ # Rearrange NCHW -> NCWH
+ inputs = tf.transpose(inputs, perm=[0, 1, 3, 2])
+ # Perform H-pooling
+ inputs = self.h_pool(inputs)
+ # Rearrange from NCWH -> NHWC
+ inputs = tf.transpose(inputs, perm=[0, 3, 2, 1])
+ return inputs
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"mode": self.mode})
+ return config
+
+
+class TFData2VecVisionPyramidPoolingModule(tf.keras.layers.Layer):
+ """
+ Pyramid Pooling Module (PPM) used in PSPNet.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ channels (int): Channels after modules, before conv_seg.
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(self, pool_scales: Tuple[int, ...], channels: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.pool_scales = pool_scales
+ self.channels = channels
+
+ self.layer_list = []
+ for idx, pool_scale in enumerate(pool_scales):
+ pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale)
+ self.layer_list.append(
+ [
+ TFAdaptiveAvgPool2D(output_shape=pool_scale),
+ TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"{idx}.1"),
+ ]
+ )
+
+ def call(self, x: tf.Tensor) -> List[tf.Tensor]:
+ ppm_outs = []
+ inputs = x
+
+ for ppm in self.layer_list:
+ for layer_module in ppm:
+ ppm_out = layer_module(x)
+ x = ppm_out
+
+ upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear")
+ ppm_outs.append(upsampled_ppm_out)
+ return ppm_outs
+
+
+class TFData2VecVisionUperHead(tf.keras.layers.Layer):
+ """
+ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
+ [UPerNet](https://arxiv.org/abs/1807.10221).
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
+ self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
+ self.channels = config.hidden_size
+ self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
+
+ # PSP Module
+ self.psp_modules = TFData2VecVisionPyramidPoolingModule(self.pool_scales, self.channels, name="psp_modules")
+ self.bottleneck = TFData2VecVisionConvModule(self.channels, kernel_size=3, padding="same", name="bottleneck")
+ # FPN Module
+ self.lateral_convs = []
+ self.fpn_convs = []
+ for idx, _ in enumerate(self.in_channels[:-1]): # skip the top layer
+ l_conv = TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}")
+ fpn_conv = TFData2VecVisionConvModule(
+ out_channels=self.channels, kernel_size=3, padding="same", name=f"fpn_convs.{idx}"
+ )
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ self.fpn_bottleneck = TFData2VecVisionConvModule(
+ out_channels=self.channels, kernel_size=3, padding="same", name="fpn_bottleneck"
+ )
+
+ def psp_forward(self, inputs):
+ x = inputs[-1]
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = tf.concat(psp_outs, axis=-1)
+ output = self.bottleneck(psp_outs)
+
+ return output
+
+ def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
+ # build laterals
+ laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+ laterals.append(self.psp_forward(encoder_hidden_states))
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = shape_list(laterals[i - 1])[1:-1]
+ laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear")
+
+ # build outputs
+ fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+ # append psp feature
+ fpn_outs.append(laterals[-1])
+
+ for i in range(used_backbone_levels - 1, 0, -1):
+ fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear")
+ fpn_outs = tf.concat(fpn_outs, axis=-1)
+ output = self.fpn_bottleneck(fpn_outs)
+ output = self.classifier(output)
+
+ return output
+
+
+class TFData2VecVisionFCNHead(tf.keras.layers.Layer):
+ """
+ Fully Convolution Networks for Semantic Segmentation. This head is implemented from
+ [FCNNet](https://arxiv.org/abs/1411.4038).
+
+ Args:
+ config (Data2VecVisionConfig): Configuration.
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
+ dilation (int): The dilation rate for convs in the head. Default: 1.
+
+
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+ """
+
+ def __init__(
+ self,
+ config: Data2VecVisionConfig,
+ in_index: int = 2,
+ kernel_size: int = 3,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.in_channels = config.hidden_size
+ self.channels = config.auxiliary_channels
+ self.num_convs = config.auxiliary_num_convs
+ self.concat_input = config.auxiliary_concat_input
+ self.in_index = in_index
+
+ convs = []
+ convs.append(
+ TFData2VecVisionConvModule(
+ out_channels=self.channels,
+ kernel_size=kernel_size,
+ padding="same",
+ dilation=dilation,
+ name="convs.0",
+ )
+ )
+ for i in range(self.num_convs - 1):
+ convs.append(
+ TFData2VecVisionConvModule(
+ out_channels=self.channels,
+ kernel_size=kernel_size,
+ padding="same",
+ dilation=dilation,
+ name=f"conv_module_{i+2}",
+ )
+ )
+ if self.num_convs == 0:
+ self.convs = [tf.identity]
+ else:
+ self.convs = convs
+ if self.concat_input:
+ self.conv_cat = TFData2VecVisionConvModule(
+ out_channels=self.channels, kernel_size=kernel_size, padding="same", name="conv_cat"
+ )
+
+ self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
+
+ def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
+ # just take the relevant feature maps
+ hidden_states = encoder_hidden_states[self.in_index]
+ output = hidden_states
+ for layer_module in self.convs:
+ output = layer_module(output)
+ if self.concat_input:
+ output = self.conv_cat(tf.concat([hidden_states, output], axis=-1))
+ output = self.classifier(output)
+ return output
+
+
+@add_start_docstrings(
+ """
+ Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
+ """,
+ DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
+ def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.num_labels = config.num_labels
+ self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision")
+
+ # FPNs
+ self.fpn1 = [
+ tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"),
+ tf.keras.layers.BatchNormalization(name="fpn1.1"),
+ tf.keras.layers.Activation("gelu"),
+ tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"),
+ ]
+ self.fpn2 = [tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")]
+
+ self.fpn3 = tf.identity
+ self.fpn4 = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)
+
+ # Semantic segmentation head(s)
+ self.decode_head = TFData2VecVisionUperHead(config, name="decode_head")
+ self.auxiliary_head = (
+ TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None
+ )
+
+ def compute_loss(self, logits, auxiliary_logits, labels):
+ # upsample logits to the images' original size
+ if len(shape_list(labels)) > 3:
+ label_interp_shape = shape_list(labels)[1:-1]
+ else:
+ label_interp_shape = shape_list(labels)[-2:]
+
+ upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
+ if auxiliary_logits is not None:
+ upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear")
+ # compute weighted loss
+ loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
+
+ # Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics.
+ # Utility to mask the index to ignore during computing the loss.
+ def masked_loss(real, pred):
+ mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index))
+ loss_ = loss_fct(real, pred)
+ mask = tf.cast(mask, dtype=loss_.dtype)
+ loss_ *= mask
+
+ return tf.reduce_sum(loss_) / tf.reduce_sum(mask)
+
+ main_loss = masked_loss(labels, upsampled_logits)
+ auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
+ loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
+
+ return loss
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ labels: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, TFSemanticSegmenterOutput]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoFeatureExtractor, TFData2VecVisionForSemanticSegmentation
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/data2vec-vision-base")
+ >>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> # logits are of shape (batch_size, num_labels, height, width)
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.data2vec_vision(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True, # we need the intermediate hidden states
+ return_dict=return_dict,
+ )
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ # only keep certain features, and reshape
+ # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
+ features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
+ batch_size = shape_list(pixel_values)[0]
+ patch_resolution = self.config.image_size // self.config.patch_size
+
+ def reshape_features(x):
+ x = tf.reshape(x, (batch_size, patch_resolution, patch_resolution, -1))
+ return x
+
+ features = [reshape_features(x[:, 1:, :]) for x in features]
+
+ # apply FPNs
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+ for module in ops[0]:
+ features[0] = module(features[0])
+ features[1] = ops[1][0](features[1])
+ for i in range(len(features[2:])):
+ features[i + 2] = ops[i + 2](features[i + 2])
+
+ logits = self.decode_head(features)
+ # Tranpose the logits to maintain consistency in the output formats.
+ transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2])
+
+ auxiliary_logits = None
+ if self.auxiliary_head is not None:
+ auxiliary_logits = self.auxiliary_head(features)
+
+ loss = None
+ if labels is not None:
+ if self.config.num_labels == 1:
+ raise ValueError("The number of labels should be greater than one")
+ else:
+ loss = self.compute_loss(logits, auxiliary_logits, labels)
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (logits,) + outputs[1:]
+ else:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSemanticSegmenterOutput(
+ loss=loss,
+ logits=transposed_logits,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/deberta/__init__.py b/src/transformers/models/deberta/__init__.py
index 38a3a22d51ccc6..8c8ebc127e07fd 100644
--- a/src/transformers/models/deberta/__init__.py
+++ b/src/transformers/models/deberta/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +32,20 @@
"tokenization_deberta": ["DebertaTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_deberta_fast"] = ["DebertaTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_deberta"] = [
"DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaForMaskedLM",
@@ -40,7 +56,12 @@
"DebertaPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_deberta"] = [
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDebertaForMaskedLM",
@@ -56,10 +77,20 @@
from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
from .tokenization_deberta import DebertaTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_deberta_fast import DebertaTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_deberta import (
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaForMaskedLM,
@@ -70,7 +101,12 @@
DebertaPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_deberta import (
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaForMaskedLM,
diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py
index a1df51ac63878b..e66241bd56c921 100644
--- a/src/transformers/models/deberta/modeling_deberta.py
+++ b/src/transformers/models/deberta/modeling_deberta.py
@@ -104,9 +104,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
- rmask = ~(mask.bool())
+ rmask = ~(mask.to(torch.bool))
- output = input.masked_fill(rmask, float("-inf"))
+ output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
@@ -129,7 +129,7 @@ def symbolic(g, self, mask, dim):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
- output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
+ output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
@@ -152,7 +152,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None
if dropout > 0 and mask is None:
- mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
if isinstance(local_context, DropoutContext):
if local_context.mask is None:
@@ -564,7 +564,7 @@ def __init__(self, config):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -652,7 +652,7 @@ def linear(w, b, x):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
diff --git a/src/transformers/models/deberta/modeling_tf_deberta.py b/src/transformers/models/deberta/modeling_tf_deberta.py
index 2b369eef5d0b0e..7099ca0c5bcd59 100644
--- a/src/transformers/models/deberta/modeling_tf_deberta.py
+++ b/src/transformers/models/deberta/modeling_tf_deberta.py
@@ -648,7 +648,12 @@ def linear(w, b, x):
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
- new_context_layer_shape = shape_list(context_layer)[:-2] + [-1]
+ context_layer_shape = shape_list(context_layer)
+ # Set the final dimension here explicitly.
+ # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
+ # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
+ # requires final input dimension to be defined
+ new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
context_layer = tf.reshape(context_layer, new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
@@ -773,7 +778,8 @@ def call(
Returns:
final_embeddings (`tf.Tensor`): output embedding tensor.
"""
- assert not (input_ids is None and inputs_embeds is None)
+ if input_ids is None and inputs_embeds is None:
+ raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
if input_ids is not None:
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
diff --git a/src/transformers/models/deberta/tokenization_deberta.py b/src/transformers/models/deberta/tokenization_deberta.py
index 6bca0ed581bacc..0ff9359fb0e7d9 100644
--- a/src/transformers/models/deberta/tokenization_deberta.py
+++ b/src/transformers/models/deberta/tokenization_deberta.py
@@ -32,7 +32,9 @@
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/vocab.json",
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/vocab.json",
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/vocab.json",
- "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/vocab.json",
+ "microsoft/deberta-xlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
@@ -40,7 +42,9 @@
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/merges.txt",
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/merges.txt",
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/merges.txt",
- "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/merges.txt",
+ "microsoft/deberta-xlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/merges.txt"
+ ),
},
}
@@ -210,7 +214,7 @@ def create_token_type_ids_from_sequences(
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
- return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
diff --git a/src/transformers/models/deberta/tokenization_deberta_fast.py b/src/transformers/models/deberta/tokenization_deberta_fast.py
index 74c2e4aca2a8a0..5b3852a6ed304f 100644
--- a/src/transformers/models/deberta/tokenization_deberta_fast.py
+++ b/src/transformers/models/deberta/tokenization_deberta_fast.py
@@ -33,7 +33,9 @@
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/vocab.json",
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/vocab.json",
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/vocab.json",
- "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/vocab.json",
+ "microsoft/deberta-xlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
@@ -41,7 +43,9 @@
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/merges.txt",
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/merges.txt",
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/merges.txt",
- "microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/merges.txt",
+ "microsoft/deberta-xlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/merges.txt"
+ ),
},
}
@@ -183,7 +187,7 @@ def create_token_type_ids_from_sequences(
sequence pair mask has the following format:
```
- 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
@@ -203,4 +207,4 @@ def create_token_type_ids_from_sequences(
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
- return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
diff --git a/src/transformers/models/deberta_v2/__init__.py b/src/transformers/models/deberta_v2/__init__.py
index a7f3cada936f42..1436f257b31737 100644
--- a/src/transformers/models/deberta_v2/__init__.py
+++ b/src/transformers/models/deberta_v2/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +32,20 @@
"tokenization_deberta_v2": ["DebertaV2Tokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_deberta_v2_fast"] = ["DebertaV2TokenizerFast"]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_deberta_v2"] = [
"TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDebertaV2ForMaskedLM",
@@ -40,10 +56,16 @@
"TFDebertaV2PreTrainedModel",
]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_deberta_v2"] = [
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaV2ForMaskedLM",
+ "DebertaV2ForMultipleChoice",
"DebertaV2ForQuestionAnswering",
"DebertaV2ForSequenceClassification",
"DebertaV2ForTokenClassification",
@@ -56,10 +78,20 @@
from .configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
from .tokenization_deberta_v2 import DebertaV2Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_deberta_v2 import (
TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDebertaV2ForMaskedLM,
@@ -70,10 +102,16 @@
TFDebertaV2PreTrainedModel,
)
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_deberta_v2 import (
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaV2ForMaskedLM,
+ DebertaV2ForMultipleChoice,
DebertaV2ForQuestionAnswering,
DebertaV2ForSequenceClassification,
DebertaV2ForTokenClassification,
diff --git a/src/transformers/models/deberta_v2/configuration_deberta_v2.py b/src/transformers/models/deberta_v2/configuration_deberta_v2.py
index 0f6f268c385224..7b81f146b9573c 100644
--- a/src/transformers/models/deberta_v2/configuration_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/configuration_deberta_v2.py
@@ -23,8 +23,12 @@
DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/config.json",
"microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/config.json",
- "microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/config.json",
- "microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/config.json",
+ "microsoft/deberta-v2-xlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/config.json"
+ ),
+ "microsoft/deberta-v2-xxlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
index c779267b7b38ed..3e57666acfdd60 100644
--- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -27,6 +27,7 @@
from ...modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
+ MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
@@ -106,9 +107,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
- rmask = ~(mask.bool())
+ rmask = ~(mask.to(torch.bool))
- output = input.masked_fill(rmask, float("-inf"))
+ output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
@@ -131,7 +132,7 @@ def symbolic(g, self, mask, dim):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
- output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
+ output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
@@ -156,7 +157,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None
if dropout > 0 and mask is None:
- mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
if isinstance(local_context, DropoutContext):
if local_context.mask is None:
@@ -637,7 +638,7 @@ def __init__(self, config):
def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
def forward(
@@ -718,7 +719,7 @@ def forward(
.contiguous()
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
@@ -1511,3 +1512,106 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ num_labels = getattr(config, "num_labels", 2)
+ self.num_labels = num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+
+ self.classifier = nn.Linear(output_dim, 1)
+ drop_out = getattr(config, "cls_dropout", None)
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+ self.dropout = StableDropout(drop_out)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.deberta.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ self.deberta.set_input_embeddings(new_embeddings)
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.deberta(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ encoder_layer = outputs[0]
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
index 5012aacf44dd3f..82f2c6a3ce210c 100644
--- a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
@@ -604,14 +604,14 @@ def __init__(self, config: DebertaV2Config, **kwargs):
if not self.share_att_key:
if "c2p" in self.pos_att_type:
- self.pos_proj = tf.keras.layers.Dense(
+ self.pos_key_proj = tf.keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
name="pos_proj",
use_bias=True,
)
if "p2c" in self.pos_att_type:
- self.pos_q_proj = tf.keras.layers.Dense(
+ self.pos_query_proj = tf.keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
name="pos_q_proj",
@@ -620,11 +620,15 @@ def __init__(self, config: DebertaV2Config, **kwargs):
self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout")
def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor:
- shape = shape_list(tensor)[:-1] + [attention_heads, -1]
+ tensor_shape = shape_list(tensor)
+ # In graph mode mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None
+ shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads]
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=shape)
+ tensor = tf.transpose(tensor, perm=[0, 2, 1, 3])
x_shape = shape_list(tensor)
- return tf.reshape(tf.transpose(tensor, perm=[0, 2, 1, 3]), shape=[-1, x_shape[1], x_shape[-1]])
+ tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]])
+ return tensor
def call(
self,
@@ -686,7 +690,6 @@ def call(
if rel_att is not None:
attention_scores = attention_scores + rel_att
- attention_scores = attention_scores
attention_scores = tf.reshape(
attention_scores,
(-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]),
@@ -706,9 +709,12 @@ def call(
),
[0, 2, 1, 3],
)
- new_context_layer_shape = shape_list(context_layer)[:-2] + [
- -1,
- ]
+ # Set the final dimension here explicitly.
+ # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
+ # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
+ # requires final input dimension to be defined
+ context_layer_shape = shape_list(context_layer)
+ new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
context_layer = tf.reshape(context_layer, new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
@@ -870,7 +876,8 @@ def call(
Returns:
final_embeddings (`tf.Tensor`): output embedding tensor.
"""
- assert not (input_ids is None and inputs_embeds is None)
+ if input_ids is None and inputs_embeds is None:
+ raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
if input_ids is not None:
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
diff --git a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py
index 577532e1becf4c..123afacf822ca9 100644
--- a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py
@@ -28,8 +28,12 @@
"vocab_file": {
"microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/spm.model",
"microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model",
- "microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model",
- "microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model",
+ "microsoft/deberta-v2-xlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model"
+ ),
+ "microsoft/deberta-v2-xxlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model"
+ ),
}
}
@@ -137,8 +141,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.do_lower_case = do_lower_case
self.split_by_punct = split_by_punct
diff --git a/src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py b/src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py
index 8aa92180d651f9..32ccd84862fa86 100644
--- a/src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py
+++ b/src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py
@@ -36,8 +36,12 @@
"vocab_file": {
"microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/spm.model",
"microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model",
- "microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model",
- "microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model",
+ "microsoft/deberta-v2-xlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model"
+ ),
+ "microsoft/deberta-v2-xxlarge-mnli": (
+ "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model"
+ ),
}
}
diff --git a/src/transformers/models/decision_transformer/__init__.py b/src/transformers/models/decision_transformer/__init__.py
index 8a72ff89c17c7b..b4b083af04b659 100644
--- a/src/transformers/models/decision_transformer/__init__.py
+++ b/src/transformers/models/decision_transformer/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -28,7 +28,12 @@
],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_decision_transformer"] = [
"DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"DecisionTransformerGPT2Model",
@@ -44,7 +49,12 @@
DecisionTransformerConfig,
)
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_decision_transformer import (
DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
DecisionTransformerGPT2Model,
diff --git a/src/transformers/models/decision_transformer/configuration_decision_transformer.py b/src/transformers/models/decision_transformer/configuration_decision_transformer.py
index 389cb0d3021a0d..01c74c247b8e71 100644
--- a/src/transformers/models/decision_transformer/configuration_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/configuration_decision_transformer.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "edbeeching/decision-transformer-gym-hopper-medium": "https://huggingface.co/edbeeching/decision-transformer-gym-hopper-medium/resolve/main/config.json",
+ "edbeeching/decision-transformer-gym-hopper-medium": (
+ "https://huggingface.co/edbeeching/decision-transformer-gym-hopper-medium/resolve/main/config.json"
+ ),
# See all DecisionTransformer models at https://huggingface.co/models?filter=decision_transformer
}
diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index b36ece31b731b7..509d7250b719df 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -17,7 +17,7 @@
import math
import os
from dataclasses import dataclass
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -135,7 +135,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
self.scale_attn_weights = config.scale_attn_weights
@@ -186,7 +187,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
@@ -278,20 +279,20 @@ def _merge_heads(self, tensor, num_heads, attn_head_size):
def forward(
self,
- hidden_states,
- layer_past=None,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- use_cache=False,
- output_attentions=False,
- ):
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
- "If class is used as cross attention, the weights `q_attn` have to be defined. "
- "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
+ "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to"
+ " instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
)
query = self.q_attn(hidden_states)
@@ -340,7 +341,7 @@ def __init__(self, intermediate_size, config):
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
@@ -369,15 +370,15 @@ def __init__(self, config, layer_idx=None):
def forward(
self,
- hidden_states,
- layer_past=None,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- use_cache=False,
- output_attentions=False,
- ):
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
@@ -510,20 +511,20 @@ def set_input_embeddings(self, new_embeddings):
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
diff --git a/src/transformers/models/deit/__init__.py b/src/transformers/models/deit/__init__.py
index 913e53f9ae8632..6c82e1aaaf5898 100644
--- a/src/transformers/models/deit/__init__.py
+++ b/src/transformers/models/deit/__init__.py
@@ -17,17 +17,25 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"],
-}
+_import_structure = {"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_deit"] = ["DeiTFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_deit"] = [
"DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"DeiTForImageClassification",
@@ -41,10 +49,20 @@
if TYPE_CHECKING:
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig, DeiTOnnxConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_deit import DeiTFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_deit import (
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
DeiTForImageClassification,
diff --git a/src/transformers/models/deit/configuration_deit.py b/src/transformers/models/deit/configuration_deit.py
index 022df1727f5830..df74664ace6133 100644
--- a/src/transformers/models/deit/configuration_deit.py
+++ b/src/transformers/models/deit/configuration_deit.py
@@ -27,7 +27,9 @@
logger = logging.get_logger(__name__)
DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/deit-base-distilled-patch16-224": "https://huggingface.co/facebook/deit-base-patch16-224/resolve/main/config.json",
+ "facebook/deit-base-distilled-patch16-224": (
+ "https://huggingface.co/facebook/deit-base-patch16-224/resolve/main/config.json"
+ ),
# See all DeiT models at https://huggingface.co/models?filter=deit
}
diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py
index 94bf5dcfbe487c..ac429c0a615fc0 100644
--- a/src/transformers/models/deit/modeling_deit.py
+++ b/src/transformers/models/deit/modeling_deit.py
@@ -168,7 +168,7 @@ def __init__(self, config: DeiTConfig) -> None:
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -200,7 +200,7 @@ def forward(
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -570,7 +570,8 @@ def forward(self, hidden_states):
@add_start_docstrings(
- "DeiT Model with a decoder on top for masked image modeling, as proposed in `SimMIM `__.",
+ "DeiT Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
+ " `__.",
DEIT_START_DOCSTRING,
)
class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
diff --git a/src/transformers/models/detr/__init__.py b/src/transformers/models/detr/__init__.py
index c7165128245dd3..5958418807d0aa 100644
--- a/src/transformers/models/detr/__init__.py
+++ b/src/transformers/models/detr/__init__.py
@@ -18,17 +18,25 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_timm_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
-_import_structure = {
- "configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"],
-}
+_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_detr"] = ["DetrFeatureExtractor"]
-if is_timm_available():
+try:
+ if not is_timm_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_detr"] = [
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DetrForObjectDetection",
@@ -41,10 +49,20 @@
if TYPE_CHECKING:
from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_detr import DetrFeatureExtractor
- if is_timm_available():
+ try:
+ if not is_timm_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_detr import (
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DetrForObjectDetection,
diff --git a/src/transformers/models/detr/feature_extraction_detr.py b/src/transformers/models/detr/feature_extraction_detr.py
index 15b37fbae7d34f..91e406c71fc944 100644
--- a/src/transformers/models/detr/feature_extraction_detr.py
+++ b/src/transformers/models/detr/feature_extraction_detr.py
@@ -538,7 +538,8 @@ def __call__(
valid_masks_path = True
if not valid_masks_path:
raise ValueError(
- "The path to the directory containing the mask PNG files should be provided as a `pathlib.Path` object."
+ "The path to the directory containing the mask PNG files should be provided as a"
+ " `pathlib.Path` object."
)
if not is_batched:
diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py
index b787aebc8aa6f1..d261104ac7ad26 100644
--- a/src/transformers/models/detr/modeling_detr.py
+++ b/src/transformers/models/detr/modeling_detr.py
@@ -489,7 +489,8 @@ def __init__(
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
@@ -553,7 +554,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -582,7 +584,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -1714,7 +1717,8 @@ def __init__(self, dim, fpn_dims, context_dim):
if dim % 8 != 0:
raise ValueError(
- "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in GroupNorm is set to 8"
+ "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
+ " GroupNorm is set to 8"
)
inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
@@ -1865,30 +1869,31 @@ class DetrLoss(nn.Module):
"""
This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
- of matched ground-truth / prediction (supervise class and box)
+ of matched ground-truth / prediction (supervise class and box).
+
+ A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
+ parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
+ the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
+ be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
+ (`max_obj_id` + 1). For more details on this, check the following discussion
+ https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
+
+
+ Args:
+ matcher (`DetrHungarianMatcher`):
+ Module able to compute a matching between targets and proposals.
+ num_classes (`int`):
+ Number of object categories, omitting the special no-object category.
+ eos_coef (`float`):
+ Relative classification weight applied to the no-object category.
+ losses (`List[str]`):
+ List of all the losses to be applied. See `get_loss` for a list of all available losses.
"""
def __init__(self, matcher, num_classes, eos_coef, losses):
- """
- Create the criterion.
-
- A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes`
- parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id
- is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass
- `num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass
- `num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
- https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
-
- Parameters:
- matcher: module able to compute a matching between targets and proposals.
- num_classes: number of object categories, omitting the special no-object category.
- weight_dict: dict containing as key the names of the losses and as values their relative weight.
- eos_coef: relative classification weight applied to the no-object category.
- losses: list of all the losses to be applied. See get_loss for list of available losses.
- """
super().__init__()
- self.num_classes = num_classes
self.matcher = matcher
+ self.num_classes = num_classes
self.eos_coef = eos_coef
self.losses = losses
empty_weight = torch.ones(self.num_classes + 1)
@@ -2017,10 +2022,12 @@ def forward(self, outputs, targets):
"""
This performs the loss computation.
- Parameters:
- outputs: dict of tensors, see the output specification of the model for the format
- targets: list of dicts, such that len(targets) == batch_size.
- The expected keys in each dict depends on the losses applied, see each loss' doc
+ Args:
+ outputs (`dict`, *optional*):
+ Dictionary of tensors, see the output specification of the model for the format.
+ targets (`List[dict]`, *optional*):
+ List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
+ losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
@@ -2086,20 +2093,18 @@ class DetrHungarianMatcher(nn.Module):
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
un-matched (and thus treated as non-objects).
+
+ Args:
+ class_cost:
+ The relative weight of the classification error in the matching cost.
+ bbox_cost:
+ The relative weight of the L1 error of the bounding box coordinates in the matching cost.
+ giou_cost:
+ The relative weight of the giou loss of the bounding box in the matching cost.
"""
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
- """
- Creates the matcher.
-
- Params:
- class_cost: This is the relative weight of the classification error in the matching cost
- bbox_cost:
- This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
- giou_cost: This is the relative weight of the giou loss of the bounding box in the matching cost
- """
super().__init__()
-
requires_backends(self, ["scipy"])
self.class_cost = class_cost
@@ -2111,25 +2116,25 @@ def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float
@torch.no_grad()
def forward(self, outputs, targets):
"""
- Performs the matching.
-
- Params:
- outputs: This is a dict that contains at least these entries:
- "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
- "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
- targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
- "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
- objects in the target) containing the class labels "boxes": Tensor of dim [num_target_boxes, 4]
- containing the target box coordinates
+ Args:
+ outputs (`dict`):
+ A dictionary that contains at least these entries:
+ * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+ * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
+ targets (`List[dict]`):
+ A list of targets (len(targets) = batch_size), where each target is a dict containing:
+ * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
+ ground-truth
+ objects in the target) containing the class labels
+ * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
Returns:
- A list of size batch_size, containing tuples of (index_i, index_j) where:
-
- - index_i is the indices of the selected predictions (in order)
- - index_j is the indices of the corresponding selected targets (in order)
+ `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
- bs, num_queries = outputs["logits"].shape[:2]
+ batch_size, num_queries = outputs["logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
@@ -2152,7 +2157,7 @@ def forward(self, outputs, targets):
# Final cost matrix
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
- cost_matrix = cost_matrix.view(bs, num_queries, -1).cpu()
+ cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
sizes = [len(v["boxes"]) for v in targets]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
@@ -2175,11 +2180,12 @@ def box_area(boxes: Tensor) -> Tensor:
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
Args:
- boxes (Tensor[N, 4]): boxes for which the area will be computed. They
- are expected to be in (x1, y1, x2, y2) format with `0 <= x1 < x2` and `0 <= y1 < y2`.
+ boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
+ Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
+ < x2` and `0 <= y1 < y2`.
Returns:
- area (Tensor[N]): area for each box
+ `torch.FloatTensor`: a tensor containing the area for each box.
"""
boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
@@ -2190,11 +2196,11 @@ def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
- lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
- rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+ left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
- wh = (rb - lt).clamp(min=0) # [N,M,2]
- inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+ width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
+ inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
@@ -2207,7 +2213,7 @@ def generalized_box_iou(boxes1, boxes2):
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
Returns:
- a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
+ `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
@@ -2242,7 +2248,6 @@ def __init__(self, tensors, mask: Optional[Tensor]):
self.mask = mask
def to(self, device):
- # type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
diff --git a/src/transformers/models/distilbert/__init__.py b/src/transformers/models/distilbert/__init__.py
index fd2e7e6a2d51da..67d4502e26939c 100644
--- a/src/transformers/models/distilbert/__init__.py
+++ b/src/transformers/models/distilbert/__init__.py
@@ -18,7 +18,14 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -30,10 +37,20 @@
"tokenization_distilbert": ["DistilBertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_distilbert_fast"] = ["DistilBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_distilbert"] = [
"DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"DistilBertForMaskedLM",
@@ -45,7 +62,12 @@
"DistilBertPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_distilbert"] = [
"TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDistilBertForMaskedLM",
@@ -58,7 +80,12 @@
"TFDistilBertPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_distilbert"] = [
"FlaxDistilBertForMaskedLM",
"FlaxDistilBertForMultipleChoice",
@@ -78,10 +105,20 @@
)
from .tokenization_distilbert import DistilBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_distilbert_fast import DistilBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_distilbert import (
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
DistilBertForMaskedLM,
@@ -93,7 +130,12 @@
DistilBertPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_distilbert import (
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDistilBertForMaskedLM,
@@ -106,7 +148,12 @@
TFDistilBertPreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_distilbert import (
FlaxDistilBertForMaskedLM,
FlaxDistilBertForMultipleChoice,
diff --git a/src/transformers/models/distilbert/configuration_distilbert.py b/src/transformers/models/distilbert/configuration_distilbert.py
index 59752bbe7e1fc2..c746ad0d64ec78 100644
--- a/src/transformers/models/distilbert/configuration_distilbert.py
+++ b/src/transformers/models/distilbert/configuration_distilbert.py
@@ -25,12 +25,20 @@
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"distilbert-base-uncased": "https://huggingface.co/distilbert-base-uncased/resolve/main/config.json",
- "distilbert-base-uncased-distilled-squad": "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/config.json",
+ "distilbert-base-uncased-distilled-squad": (
+ "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/config.json"
+ ),
"distilbert-base-cased": "https://huggingface.co/distilbert-base-cased/resolve/main/config.json",
- "distilbert-base-cased-distilled-squad": "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json",
+ "distilbert-base-cased-distilled-squad": (
+ "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json"
+ ),
"distilbert-base-german-cased": "https://huggingface.co/distilbert-base-german-cased/resolve/main/config.json",
- "distilbert-base-multilingual-cased": "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/config.json",
- "distilbert-base-uncased-finetuned-sst-2-english": "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
+ "distilbert-base-multilingual-cased": (
+ "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/config.json"
+ ),
+ "distilbert-base-uncased-finetuned-sst-2-english": (
+ "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index 5ef86541f234d0..a93c08345a0bf5 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -211,7 +211,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
- scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
+ scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
diff --git a/src/transformers/models/distilbert/tokenization_distilbert.py b/src/transformers/models/distilbert/tokenization_distilbert.py
index 694c0ad25aa01a..9408ca0b0f6989 100644
--- a/src/transformers/models/distilbert/tokenization_distilbert.py
+++ b/src/transformers/models/distilbert/tokenization_distilbert.py
@@ -25,11 +25,17 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"distilbert-base-uncased": "https://huggingface.co/distilbert-base-uncased/resolve/main/vocab.txt",
- "distilbert-base-uncased-distilled-squad": "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/vocab.txt",
+ "distilbert-base-uncased-distilled-squad": (
+ "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/vocab.txt"
+ ),
"distilbert-base-cased": "https://huggingface.co/distilbert-base-cased/resolve/main/vocab.txt",
- "distilbert-base-cased-distilled-squad": "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/vocab.txt",
+ "distilbert-base-cased-distilled-squad": (
+ "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/vocab.txt"
+ ),
"distilbert-base-german-cased": "https://huggingface.co/distilbert-base-german-cased/resolve/main/vocab.txt",
- "distilbert-base-multilingual-cased": "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/vocab.txt",
+ "distilbert-base-multilingual-cased": (
+ "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/distilbert/tokenization_distilbert_fast.py b/src/transformers/models/distilbert/tokenization_distilbert_fast.py
index 6a4ddfb81986c5..fdd69dc3e01aa6 100644
--- a/src/transformers/models/distilbert/tokenization_distilbert_fast.py
+++ b/src/transformers/models/distilbert/tokenization_distilbert_fast.py
@@ -26,19 +26,33 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"distilbert-base-uncased": "https://huggingface.co/distilbert-base-uncased/resolve/main/vocab.txt",
- "distilbert-base-uncased-distilled-squad": "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/vocab.txt",
+ "distilbert-base-uncased-distilled-squad": (
+ "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/vocab.txt"
+ ),
"distilbert-base-cased": "https://huggingface.co/distilbert-base-cased/resolve/main/vocab.txt",
- "distilbert-base-cased-distilled-squad": "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/vocab.txt",
+ "distilbert-base-cased-distilled-squad": (
+ "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/vocab.txt"
+ ),
"distilbert-base-german-cased": "https://huggingface.co/distilbert-base-german-cased/resolve/main/vocab.txt",
- "distilbert-base-multilingual-cased": "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/vocab.txt",
+ "distilbert-base-multilingual-cased": (
+ "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
"distilbert-base-uncased": "https://huggingface.co/distilbert-base-uncased/resolve/main/tokenizer.json",
- "distilbert-base-uncased-distilled-squad": "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/tokenizer.json",
+ "distilbert-base-uncased-distilled-squad": (
+ "https://huggingface.co/distilbert-base-uncased-distilled-squad/resolve/main/tokenizer.json"
+ ),
"distilbert-base-cased": "https://huggingface.co/distilbert-base-cased/resolve/main/tokenizer.json",
- "distilbert-base-cased-distilled-squad": "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/tokenizer.json",
- "distilbert-base-german-cased": "https://huggingface.co/distilbert-base-german-cased/resolve/main/tokenizer.json",
- "distilbert-base-multilingual-cased": "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/tokenizer.json",
+ "distilbert-base-cased-distilled-squad": (
+ "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/tokenizer.json"
+ ),
+ "distilbert-base-german-cased": (
+ "https://huggingface.co/distilbert-base-german-cased/resolve/main/tokenizer.json"
+ ),
+ "distilbert-base-multilingual-cased": (
+ "https://huggingface.co/distilbert-base-multilingual-cased/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/dpr/__init__.py b/src/transformers/models/dpr/__init__.py
index 5ee5962258842d..8f9482364347fc 100644
--- a/src/transformers/models/dpr/__init__.py
+++ b/src/transformers/models/dpr/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -32,14 +38,24 @@
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_dpr_fast"] = [
"DPRContextEncoderTokenizerFast",
"DPRQuestionEncoderTokenizerFast",
"DPRReaderTokenizerFast",
]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_dpr"] = [
"DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST",
"DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -53,7 +69,12 @@
"DPRReader",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_dpr"] = [
"TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -76,14 +97,24 @@
DPRReaderTokenizer,
)
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_dpr_fast import (
DPRContextEncoderTokenizerFast,
DPRQuestionEncoderTokenizerFast,
DPRReaderTokenizerFast,
)
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -97,7 +128,12 @@
DPRReader,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_dpr import (
TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
diff --git a/src/transformers/models/dpr/configuration_dpr.py b/src/transformers/models/dpr/configuration_dpr.py
index 0828f0a92cab21..799f9aae4e23c5 100644
--- a/src/transformers/models/dpr/configuration_dpr.py
+++ b/src/transformers/models/dpr/configuration_dpr.py
@@ -21,12 +21,24 @@
logger = logging.get_logger(__name__)
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/dpr-ctx_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/config.json",
- "facebook/dpr-question_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/config.json",
- "facebook/dpr-reader-single-nq-base": "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/config.json",
- "facebook/dpr-ctx_encoder-multiset-base": "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/config.json",
- "facebook/dpr-question_encoder-multiset-base": "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/config.json",
- "facebook/dpr-reader-multiset-base": "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/config.json",
+ "facebook/dpr-ctx_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/config.json"
+ ),
+ "facebook/dpr-question_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/config.json"
+ ),
+ "facebook/dpr-reader-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/config.json"
+ ),
+ "facebook/dpr-ctx_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/config.json"
+ ),
+ "facebook/dpr-question_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/config.json"
+ ),
+ "facebook/dpr-reader-multiset-base": (
+ "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py b/src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py
index c6484581b7e5f8..6ea85620242f06 100644
--- a/src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py
+++ b/src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py
@@ -124,7 +124,11 @@ def convert(comp_type: str, src_file: Path, dest_dir: Path):
parser.add_argument(
"--src",
type=str,
- help="Path to the dpr checkpoint file. They can be downloaded from the official DPR repo https://github.com/facebookresearch/DPR. Note that in the official repo, both encoders are stored in the 'retriever' checkpoints.",
+ help=(
+ "Path to the dpr checkpoint file. They can be downloaded from the official DPR repo"
+ " https://github.com/facebookresearch/DPR. Note that in the official repo, both encoders are stored in the"
+ " 'retriever' checkpoints."
+ ),
)
parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model directory.")
args = parser.parse_args()
diff --git a/src/transformers/models/dpr/tokenization_dpr.py b/src/transformers/models/dpr/tokenization_dpr.py
index 8edaf2d3d1b01b..208b9c377ed5c0 100644
--- a/src/transformers/models/dpr/tokenization_dpr.py
+++ b/src/transformers/models/dpr/tokenization_dpr.py
@@ -29,32 +29,56 @@
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/dpr-ctx_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/vocab.txt",
- "facebook/dpr-ctx_encoder-multiset-base": "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/vocab.txt",
+ "facebook/dpr-ctx_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/vocab.txt"
+ ),
+ "facebook/dpr-ctx_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "facebook/dpr-ctx_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/tokenizer.json",
- "facebook/dpr-ctx_encoder-multiset-base": "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/tokenizer.json",
+ "facebook/dpr-ctx_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/tokenizer.json"
+ ),
+ "facebook/dpr-ctx_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/tokenizer.json"
+ ),
},
}
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/dpr-question_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/vocab.txt",
- "facebook/dpr-question_encoder-multiset-base": "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/vocab.txt",
+ "facebook/dpr-question_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/vocab.txt"
+ ),
+ "facebook/dpr-question_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "facebook/dpr-question_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/tokenizer.json",
- "facebook/dpr-question_encoder-multiset-base": "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/tokenizer.json",
+ "facebook/dpr-question_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/tokenizer.json"
+ ),
+ "facebook/dpr-question_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/tokenizer.json"
+ ),
},
}
READER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/dpr-reader-single-nq-base": "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/vocab.txt",
- "facebook/dpr-reader-multiset-base": "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/vocab.txt",
+ "facebook/dpr-reader-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/vocab.txt"
+ ),
+ "facebook/dpr-reader-multiset-base": (
+ "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "facebook/dpr-reader-single-nq-base": "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/tokenizer.json",
- "facebook/dpr-reader-multiset-base": "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/tokenizer.json",
+ "facebook/dpr-reader-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/tokenizer.json"
+ ),
+ "facebook/dpr-reader-multiset-base": (
+ "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/tokenizer.json"
+ ),
},
}
@@ -342,8 +366,8 @@ def _get_best_spans(
`span_score` order and keeping max `top_spans` spans. Spans longer that `max_answer_length` are ignored.
"""
scores = []
- for (start_index, start_score) in enumerate(start_logits):
- for (answer_length, end_score) in enumerate(end_logits[start_index : start_index + max_answer_length]):
+ for start_index, start_score in enumerate(start_logits):
+ for answer_length, end_score in enumerate(end_logits[start_index : start_index + max_answer_length]):
scores.append(((start_index, start_index + answer_length), start_score + end_score))
scores = sorted(scores, key=lambda x: x[1], reverse=True)
chosen_span_intervals = []
diff --git a/src/transformers/models/dpr/tokenization_dpr_fast.py b/src/transformers/models/dpr/tokenization_dpr_fast.py
index ea021dcb6ab163..486eb9f38707c6 100644
--- a/src/transformers/models/dpr/tokenization_dpr_fast.py
+++ b/src/transformers/models/dpr/tokenization_dpr_fast.py
@@ -30,32 +30,56 @@
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/dpr-ctx_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/vocab.txt",
- "facebook/dpr-ctx_encoder-multiset-base": "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/vocab.txt",
+ "facebook/dpr-ctx_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/vocab.txt"
+ ),
+ "facebook/dpr-ctx_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "facebook/dpr-ctx_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/tokenizer.json",
- "facebook/dpr-ctx_encoder-multiset-base": "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/tokenizer.json",
+ "facebook/dpr-ctx_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base/resolve/main/tokenizer.json"
+ ),
+ "facebook/dpr-ctx_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-ctx_encoder-multiset-base/resolve/main/tokenizer.json"
+ ),
},
}
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/dpr-question_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/vocab.txt",
- "facebook/dpr-question_encoder-multiset-base": "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/vocab.txt",
+ "facebook/dpr-question_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/vocab.txt"
+ ),
+ "facebook/dpr-question_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "facebook/dpr-question_encoder-single-nq-base": "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/tokenizer.json",
- "facebook/dpr-question_encoder-multiset-base": "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/tokenizer.json",
+ "facebook/dpr-question_encoder-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-single-nq-base/resolve/main/tokenizer.json"
+ ),
+ "facebook/dpr-question_encoder-multiset-base": (
+ "https://huggingface.co/facebook/dpr-question_encoder-multiset-base/resolve/main/tokenizer.json"
+ ),
},
}
READER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/dpr-reader-single-nq-base": "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/vocab.txt",
- "facebook/dpr-reader-multiset-base": "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/vocab.txt",
+ "facebook/dpr-reader-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/vocab.txt"
+ ),
+ "facebook/dpr-reader-multiset-base": (
+ "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "facebook/dpr-reader-single-nq-base": "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/tokenizer.json",
- "facebook/dpr-reader-multiset-base": "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/tokenizer.json",
+ "facebook/dpr-reader-single-nq-base": (
+ "https://huggingface.co/facebook/dpr-reader-single-nq-base/resolve/main/tokenizer.json"
+ ),
+ "facebook/dpr-reader-multiset-base": (
+ "https://huggingface.co/facebook/dpr-reader-multiset-base/resolve/main/tokenizer.json"
+ ),
},
}
@@ -342,8 +366,8 @@ def _get_best_spans(
`span_score` order and keeping max `top_spans` spans. Spans longer that `max_answer_length` are ignored.
"""
scores = []
- for (start_index, start_score) in enumerate(start_logits):
- for (answer_length, end_score) in enumerate(end_logits[start_index : start_index + max_answer_length]):
+ for start_index, start_score in enumerate(start_logits):
+ for answer_length, end_score in enumerate(end_logits[start_index : start_index + max_answer_length]):
scores.append(((start_index, start_index + answer_length), start_score + end_score))
scores = sorted(scores, key=lambda x: x[1], reverse=True)
chosen_span_intervals = []
diff --git a/src/transformers/models/dpt/__init__.py b/src/transformers/models/dpt/__init__.py
index ba895de9b9fd6f..1df82ab6282465 100644
--- a/src/transformers/models/dpt/__init__.py
+++ b/src/transformers/models/dpt/__init__.py
@@ -18,16 +18,25 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable
-_import_structure = {
- "configuration_dpt": ["DPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DPTConfig"],
-}
+_import_structure = {"configuration_dpt": ["DPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DPTConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_dpt"] = ["DPTFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_dpt"] = [
"DPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"DPTForDepthEstimation",
@@ -40,10 +49,20 @@
if TYPE_CHECKING:
from .configuration_dpt import DPT_PRETRAINED_CONFIG_ARCHIVE_MAP, DPTConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_dpt import DPTFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_dpt import (
DPT_PRETRAINED_MODEL_ARCHIVE_LIST,
DPTForDepthEstimation,
diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py
index 6c5fd2385232c0..64ea40a5c534f1 100755
--- a/src/transformers/models/dpt/modeling_dpt.py
+++ b/src/transformers/models/dpt/modeling_dpt.py
@@ -177,7 +177,7 @@ def __init__(self, config: DPTConfig) -> None:
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -209,7 +209,7 @@ def forward(
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
diff --git a/src/transformers/models/electra/__init__.py b/src/transformers/models/electra/__init__.py
index ad818da549271d..59e3ca47794173 100644
--- a/src/transformers/models/electra/__init__.py
+++ b/src/transformers/models/electra/__init__.py
@@ -18,7 +18,14 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +33,20 @@
"tokenization_electra": ["ElectraTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_electra_fast"] = ["ElectraTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_electra"] = [
"ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
"ElectraForCausalLM",
@@ -44,7 +61,12 @@
"load_tf_weights_in_electra",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_electra"] = [
"TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFElectraForMaskedLM",
@@ -57,8 +79,14 @@
"TFElectraPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_electra"] = [
+ "FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
@@ -74,10 +102,20 @@
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraOnnxConfig
from .tokenization_electra import ElectraTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_electra_fast import ElectraTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_electra import (
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
ElectraForCausalLM,
@@ -92,7 +130,12 @@
load_tf_weights_in_electra,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
@@ -105,8 +148,14 @@
TFElectraPreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_electra import (
+ FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
diff --git a/src/transformers/models/electra/configuration_electra.py b/src/transformers/models/electra/configuration_electra.py
index 765498ef833b17..3ea54aa7ca9a01 100644
--- a/src/transformers/models/electra/configuration_electra.py
+++ b/src/transformers/models/electra/configuration_electra.py
@@ -29,9 +29,15 @@
"google/electra-small-generator": "https://huggingface.co/google/electra-small-generator/resolve/main/config.json",
"google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/config.json",
"google/electra-large-generator": "https://huggingface.co/google/electra-large-generator/resolve/main/config.json",
- "google/electra-small-discriminator": "https://huggingface.co/google/electra-small-discriminator/resolve/main/config.json",
- "google/electra-base-discriminator": "https://huggingface.co/google/electra-base-discriminator/resolve/main/config.json",
- "google/electra-large-discriminator": "https://huggingface.co/google/electra-large-discriminator/resolve/main/config.json",
+ "google/electra-small-discriminator": (
+ "https://huggingface.co/google/electra-small-discriminator/resolve/main/config.json"
+ ),
+ "google/electra-base-discriminator": (
+ "https://huggingface.co/google/electra-base-discriminator/resolve/main/config.json"
+ ),
+ "google/electra-large-discriminator": (
+ "https://huggingface.co/google/electra-large-discriminator/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py
index 0e8a5c59177938..d5d6376d7b9942 100644
--- a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py
@@ -59,8 +59,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained model. \n"
- "This specifies the model architecture.",
+ help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
@@ -70,8 +69,10 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or "
- "'generator'.",
+ help=(
+ "Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or "
+ "'generator'."
+ ),
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index 0c21d546ecc092..3f488fbcf5648b 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -174,8 +174,13 @@ def __init__(self, config):
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(
- self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
- ):
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values_length: int = 0,
+ ) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -238,7 +243,7 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@@ -482,7 +487,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -664,7 +670,7 @@ class ElectraPreTrainedModel(PreTrainedModel):
base_model_prefix = "electra"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
- _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
+ _keys_to_ignore_on_load_unexpected = [r"electra.embeddings_project.weight", r"electra.embeddings_project.bias"]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
@@ -850,7 +856,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -882,7 +888,7 @@ def forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -985,7 +991,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutput]:
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1075,7 +1081,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, ElectraForPreTrainingOutput]:
+ ) -> Union[Tuple[torch.Tensor], ElectraForPreTrainingOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
@@ -1197,7 +1203,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, MaskedLMOutput]:
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1283,7 +1289,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, TokenClassifierOutput]:
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1368,7 +1374,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1469,7 +1475,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, MultipleChoiceModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1564,7 +1570,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py
index 4690a0ad64ad8c..3e3a7103f07e30 100644
--- a/src/transformers/models/electra/modeling_flax_electra.py
+++ b/src/transformers/models/electra/modeling_flax_electra.py
@@ -22,13 +22,15 @@
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
-from jax.random import PRNGKey
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
@@ -184,13 +186,15 @@ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, dete
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra
class FlaxElectraSelfAttention(nn.Module):
config: ElectraConfig
+ causal: bool = False
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
+ self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
- "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
- : {self.config.num_attention_heads}"
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
+ " : {self.config.num_attention_heads}"
)
self.query = nn.Dense(
@@ -209,30 +213,113 @@ def setup(self):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
+
+ @nn.compact
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slighly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
+ key_value_states: Optional[jnp.array] = None,
+ init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
):
- head_dim = self.config.hidden_size // self.config.num_attention_heads
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ # get query proj
+ query_states = self.query(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.key(key_value_states)
+ value_states = self.value(key_value_states)
+ else:
+ # self_attention
+ key_states = self.key(hidden_states)
+ value_states = self.value(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
- query_states = self.query(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
- value_states = self.value(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
- key_states = self.key(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
- attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
@@ -292,10 +379,11 @@ def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra
class FlaxElectraAttention(nn.Module):
config: ElectraConfig
+ causal: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
- self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype)
+ self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
def __call__(
@@ -303,6 +391,8 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
+ key_value_states=None,
+ init_cache=False,
deterministic=True,
output_attentions: bool = False,
):
@@ -313,6 +403,8 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
+ key_value_states=key_value_states,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
@@ -373,27 +465,46 @@ class FlaxElectraLayer(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
- self.attention = FlaxElectraAttention(self.config, dtype=self.dtype)
+ self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
+ if self.config.add_cross_attention:
+ self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
):
+ # Self Attention
attention_outputs = self.attention(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
+ # Cross-Attention Block
+ if encoder_hidden_states is not None:
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=layer_head_mask,
+ key_value_states=encoder_hidden_states,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
@@ -401,6 +512,8 @@ def __call__(
if output_attentions:
outputs += (attention_outputs[1],)
+ if encoder_hidden_states is not None:
+ outputs += (cross_attention_outputs[1],)
return outputs
@@ -419,6 +532,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -426,13 +542,14 @@ def __call__(
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# Check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
- {head_mask.shape[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
+ f" {head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
@@ -443,6 +560,9 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
@@ -452,6 +572,9 @@ def __call__(
if output_attentions:
all_attentions += (layer_outputs[1],)
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -460,8 +583,11 @@ def __call__(
if not return_dict:
return tuple(v for v in outputs if v is not None)
- return FlaxBaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
)
@@ -478,6 +604,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -487,6 +616,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -548,6 +680,7 @@ def __init__(
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -559,9 +692,26 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
- random_params = self.module.init(
- rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
- )["params"]
+ if self.config.add_cross_attention:
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
+ encoder_attention_mask = attention_mask
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ return_dict=False,
+ )
+ else:
+ module_init_outputs = self.module.init(
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
+ )
+
+ random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
@@ -573,6 +723,26 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
else:
return random_params
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
@@ -581,12 +751,15 @@ def __call__(
token_type_ids=None,
position_ids=None,
head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
params: dict = None,
- dropout_rng: PRNGKey = None,
+ dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ past_key_values: dict = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -613,19 +786,60 @@ def __call__(
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
- return self.module.apply(
- {"params": params or self.params},
- jnp.array(input_ids, dtype="i4"),
- jnp.array(attention_mask, dtype="i4"),
- jnp.array(token_type_ids, dtype="i4"),
- jnp.array(position_ids, dtype="i4"),
- jnp.array(head_mask, dtype="i4"),
- not train,
- output_attentions,
- output_hidden_states,
- return_dict,
- rngs=rngs,
- )
+ inputs = {"params": params or self.params}
+
+ if self.config.add_cross_attention:
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
+ # changed by FlaxElectraAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ head_mask=jnp.array(head_mask, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ deterministic=not train,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ else:
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ head_mask=jnp.array(head_mask, dtype="i4"),
+ deterministic=not train,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ rngs=rngs,
+ )
+
+ return outputs
class FlaxElectraModule(nn.Module):
@@ -645,6 +859,9 @@ def __call__(
token_type_ids,
position_ids,
head_mask: Optional[np.ndarray] = None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -661,6 +878,9 @@ def __call__(
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -1232,3 +1452,111 @@ class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel):
FlaxSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
+
+
+class FlaxElectraForCausalLMModule(nn.Module):
+ config: ElectraConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
+ self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
+ if self.config.tie_word_embeddings:
+ self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
+ else:
+ self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask: Optional[jnp.ndarray] = None,
+ token_type_ids: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ head_mask: Optional[jnp.ndarray] = None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ outputs = self.electra(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0]
+ prediction_scores = self.generator_predictions(hidden_states)
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+ prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
+ else:
+ prediction_scores = self.generator_lm_head(prediction_scores)
+
+ if not return_dict:
+ return (prediction_scores,) + outputs[1:]
+
+ return FlaxCausalLMOutputWithCrossAttentions(
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
+ autoregressive tasks.
+ """,
+ ELECTRA_START_DOCSTRING,
+)
+# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra
+class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
+ module_class = FlaxElectraForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ FlaxElectraForCausalLM,
+ _TOKENIZER_FOR_DOC,
+ _CHECKPOINT_FOR_DOC,
+ FlaxCausalLMOutputWithCrossAttentions,
+ _CONFIG_FOR_DOC,
+)
diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py
index 6483988a30e422..57f17c8a97476a 100644
--- a/src/transformers/models/electra/modeling_tf_electra.py
+++ b/src/transformers/models/electra/modeling_tf_electra.py
@@ -344,8 +344,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
diff --git a/src/transformers/models/electra/tokenization_electra.py b/src/transformers/models/electra/tokenization_electra.py
index 9fd5568cde31d2..2feeaaa2a7485a 100644
--- a/src/transformers/models/electra/tokenization_electra.py
+++ b/src/transformers/models/electra/tokenization_electra.py
@@ -20,12 +20,22 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/electra-small-generator": "https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt",
+ "google/electra-small-generator": (
+ "https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt"
+ ),
"google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
- "google/electra-large-generator": "https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt",
- "google/electra-small-discriminator": "https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt",
- "google/electra-base-discriminator": "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt",
- "google/electra-large-discriminator": "https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt",
+ "google/electra-large-generator": (
+ "https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt"
+ ),
+ "google/electra-small-discriminator": (
+ "https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt"
+ ),
+ "google/electra-base-discriminator": (
+ "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt"
+ ),
+ "google/electra-large-discriminator": (
+ "https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/electra/tokenization_electra_fast.py b/src/transformers/models/electra/tokenization_electra_fast.py
index 48a28cc98b9dd0..c37163672c81d3 100644
--- a/src/transformers/models/electra/tokenization_electra_fast.py
+++ b/src/transformers/models/electra/tokenization_electra_fast.py
@@ -21,20 +21,42 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/electra-small-generator": "https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt",
+ "google/electra-small-generator": (
+ "https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt"
+ ),
"google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
- "google/electra-large-generator": "https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt",
- "google/electra-small-discriminator": "https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt",
- "google/electra-base-discriminator": "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt",
- "google/electra-large-discriminator": "https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt",
+ "google/electra-large-generator": (
+ "https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt"
+ ),
+ "google/electra-small-discriminator": (
+ "https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt"
+ ),
+ "google/electra-base-discriminator": (
+ "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt"
+ ),
+ "google/electra-large-discriminator": (
+ "https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "google/electra-small-generator": "https://huggingface.co/google/electra-small-generator/resolve/main/tokenizer.json",
- "google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/tokenizer.json",
- "google/electra-large-generator": "https://huggingface.co/google/electra-large-generator/resolve/main/tokenizer.json",
- "google/electra-small-discriminator": "https://huggingface.co/google/electra-small-discriminator/resolve/main/tokenizer.json",
- "google/electra-base-discriminator": "https://huggingface.co/google/electra-base-discriminator/resolve/main/tokenizer.json",
- "google/electra-large-discriminator": "https://huggingface.co/google/electra-large-discriminator/resolve/main/tokenizer.json",
+ "google/electra-small-generator": (
+ "https://huggingface.co/google/electra-small-generator/resolve/main/tokenizer.json"
+ ),
+ "google/electra-base-generator": (
+ "https://huggingface.co/google/electra-base-generator/resolve/main/tokenizer.json"
+ ),
+ "google/electra-large-generator": (
+ "https://huggingface.co/google/electra-large-generator/resolve/main/tokenizer.json"
+ ),
+ "google/electra-small-discriminator": (
+ "https://huggingface.co/google/electra-small-discriminator/resolve/main/tokenizer.json"
+ ),
+ "google/electra-base-discriminator": (
+ "https://huggingface.co/google/electra-base-discriminator/resolve/main/tokenizer.json"
+ ),
+ "google/electra-large-discriminator": (
+ "https://huggingface.co/google/electra-large-discriminator/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/encoder_decoder/__init__.py b/src/transformers/models/encoder_decoder/__init__.py
index 2e1e7e1eb60576..759b49f50d3363 100644
--- a/src/transformers/models/encoder_decoder/__init__.py
+++ b/src/transformers/models/encoder_decoder/__init__.py
@@ -18,32 +18,66 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_encoder_decoder": ["EncoderDecoderConfig"],
-}
+_import_structure = {"configuration_encoder_decoder": ["EncoderDecoderConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_encoder_decoder"] = ["TFEncoderDecoderModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_encoder_decoder import EncoderDecoderConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_encoder_decoder import EncoderDecoderModel
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_encoder_decoder import TFEncoderDecoderModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel
else:
diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
index 972b80db7b4dbe..a7ff6a7e3aa95f 100644
--- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -15,7 +15,7 @@
""" Classes to support Encoder-Decoder architectures"""
import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union
import torch
from torch import nn
@@ -35,10 +35,10 @@
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
DEPRECATION_WARNING = (
- "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
- "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
- "a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no "
- "need to pass them yourself anymore."
+ "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+ " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+ " fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the"
+ " labels, no need to pass them yourself anymore."
)
ENCODER_DECODER_START_DOCSTRING = r"""
@@ -189,10 +189,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# initialize with config
@@ -213,11 +213,13 @@ def __init__(
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
@@ -401,10 +403,9 @@ def from_encoder_decoder_pretrained(
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
@@ -430,21 +431,21 @@ def from_encoder_decoder_pretrained(
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- decoder_input_ids=None,
- decoder_attention_mask=None,
- encoder_outputs=None,
- past_key_values=None,
- inputs_embeds=None,
- decoder_inputs_embeds=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
**kwargs,
- ):
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
r"""
Returns:
@@ -457,7 +458,7 @@ def forward(
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
... "bert-base-uncased", "bert-base-uncased"
- >>> ) # initialize Bert2Bert from pre-trained checkpoints
+ ... ) # initialize Bert2Bert from pre-trained checkpoints
>>> # training
>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
@@ -572,8 +573,9 @@ def prepare_inputs_for_generation(
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
- "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. "
- "Please use the respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or model.decoder.resize_token_embeddings(...))"
+ "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
index 7ffc81687d8e36..36df84f3055341 100644
--- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
@@ -330,10 +330,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -354,7 +354,8 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
if not decoder_batch_size == batch_size:
raise ValueError(
- f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder"
+ f" and {decoder_batch_size} for decoder."
)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
@@ -593,7 +594,7 @@ def _decoder_forward(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
- encoder_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
**kwargs,
)
@@ -689,7 +690,8 @@ def __call__(
# prepare decoder inputs
if decoder_input_ids is None:
raise ValueError(
- "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must"
+ " be specified as an input argument."
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
@@ -869,10 +871,9 @@ def from_encoder_decoder_pretrained(
)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
index 9e92e767b1b857..5c74e8433e6d6e 100644
--- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -43,10 +43,10 @@
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
DEPRECATION_WARNING = (
- "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
- "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
- "a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
- "need to pass them yourself anymore."
+ "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+ " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+ " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the"
+ " labels, no need to pass them yourself anymore."
)
ENCODER_DECODER_START_DOCSTRING = r"""
@@ -211,10 +211,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# initialize with config
@@ -231,11 +231,13 @@ def __init__(
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
@@ -319,10 +321,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
from_pt = kwargs.pop("from_pt", False)
if from_pt:
raise ValueError(
- "Initializing `TFEncoderDecoderModel` from a pytorch checkpoint is not supported currently. "
- "Use a tensorflow checkpoint instead. If only the pytorch checkpoints are available, "
- "create the encoder and decoder models separately, and use them to initialize `TFEncoderDecoderModel`. "
- "Check `TFEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
+ "Initializing `TFEncoderDecoderModel` from a pytorch checkpoint is not supported currently. Use a"
+ " tensorflow checkpoint instead. If only the pytorch checkpoints are available, create the encoder and"
+ " decoder models separately, and use them to initialize `TFEncoderDecoderModel`. Check"
+ " `TFEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
@@ -450,10 +452,9 @@ def from_encoder_decoder_pretrained(
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
@@ -528,7 +529,7 @@ def call(
>>> # forward
>>> input_ids = tokenizer.encode(
... "Hello, my dog is cute", add_special_tokens=True, return_tensors="tf"
- >>> ) # Batch size 1
+ ... ) # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
>>> # training
@@ -702,8 +703,9 @@ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
- "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported."
- "Please use the respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or model.decoder.resize_token_embeddings(...))"
+ "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported.Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
diff --git a/src/transformers/models/flaubert/__init__.py b/src/transformers/models/flaubert/__init__.py
index fa4e31eeb6c1ea..95741cab2ebd0c 100644
--- a/src/transformers/models/flaubert/__init__.py
+++ b/src/transformers/models/flaubert/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_flaubert": ["FlaubertTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flaubert"] = [
"FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"FlaubertForMultipleChoice",
@@ -38,7 +43,12 @@
"FlaubertWithLMHeadModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_flaubert"] = [
"TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFFlaubertForMultipleChoice",
@@ -55,7 +65,12 @@
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertOnnxConfig
from .tokenization_flaubert import FlaubertTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flaubert import (
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
FlaubertForMultipleChoice,
@@ -67,7 +82,12 @@
FlaubertWithLMHeadModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_flaubert import (
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFFlaubertForMultipleChoice,
diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py
index d4bd3f53fdb339..bc49216221e2b7 100644
--- a/src/transformers/models/flaubert/modeling_tf_flaubert.py
+++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py
@@ -182,8 +182,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
mask = padding_mask
else:
# assert lengths.max().item() <= slen
- alen = tf.range(slen)
- mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+ alen = tf.range(slen, dtype=lengths.dtype)
+ mask = alen < tf.expand_dims(lengths, axis=1)
# attention mask is the same as mask, or triangular inferior attention (causal)
if causal:
diff --git a/src/transformers/models/flaubert/tokenization_flaubert.py b/src/transformers/models/flaubert/tokenization_flaubert.py
index 828525d756afd8..4fbb3783d8a38b 100644
--- a/src/transformers/models/flaubert/tokenization_flaubert.py
+++ b/src/transformers/models/flaubert/tokenization_flaubert.py
@@ -32,16 +32,28 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "flaubert/flaubert_small_cased": "https://huggingface.co/flaubert/flaubert_small_cased/resolve/main/vocab.json",
- "flaubert/flaubert_base_uncased": "https://huggingface.co/flaubert/flaubert_base_uncased/resolve/main/vocab.json",
+ "flaubert/flaubert_small_cased": (
+ "https://huggingface.co/flaubert/flaubert_small_cased/resolve/main/vocab.json"
+ ),
+ "flaubert/flaubert_base_uncased": (
+ "https://huggingface.co/flaubert/flaubert_base_uncased/resolve/main/vocab.json"
+ ),
"flaubert/flaubert_base_cased": "https://huggingface.co/flaubert/flaubert_base_cased/resolve/main/vocab.json",
- "flaubert/flaubert_large_cased": "https://huggingface.co/flaubert/flaubert_large_cased/resolve/main/vocab.json",
+ "flaubert/flaubert_large_cased": (
+ "https://huggingface.co/flaubert/flaubert_large_cased/resolve/main/vocab.json"
+ ),
},
"merges_file": {
- "flaubert/flaubert_small_cased": "https://huggingface.co/flaubert/flaubert_small_cased/resolve/main/merges.txt",
- "flaubert/flaubert_base_uncased": "https://huggingface.co/flaubert/flaubert_base_uncased/resolve/main/merges.txt",
+ "flaubert/flaubert_small_cased": (
+ "https://huggingface.co/flaubert/flaubert_small_cased/resolve/main/merges.txt"
+ ),
+ "flaubert/flaubert_base_uncased": (
+ "https://huggingface.co/flaubert/flaubert_base_uncased/resolve/main/merges.txt"
+ ),
"flaubert/flaubert_base_cased": "https://huggingface.co/flaubert/flaubert_base_cased/resolve/main/merges.txt",
- "flaubert/flaubert_large_cased": "https://huggingface.co/flaubert/flaubert_large_cased/resolve/main/merges.txt",
+ "flaubert/flaubert_large_cased": (
+ "https://huggingface.co/flaubert/flaubert_large_cased/resolve/main/merges.txt"
+ ),
},
}
@@ -130,7 +142,8 @@ def _tokenize(self, text, bypass_tokenizer=False):
lang = "fr"
if lang and self.lang2id and lang not in self.lang2id:
logger.error(
- "Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model."
+ "Supplied language code not found in lang2id mapping. Please check that your language is supported by"
+ " the loaded pretrained model."
)
if bypass_tokenizer:
diff --git a/src/transformers/models/flava/__init__.py b/src/transformers/models/flava/__init__.py
new file mode 100644
index 00000000000000..29d8240032a434
--- /dev/null
+++ b/src/transformers/models/flava/__init__.py
@@ -0,0 +1,99 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {
+ "configuration_flava": [
+ "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "FlavaConfig",
+ "FlavaImageCodebookConfig",
+ "FlavaImageConfig",
+ "FlavaMultimodalConfig",
+ "FlavaTextConfig",
+ ],
+}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_flava"] = ["FlavaFeatureExtractor"]
+ _import_structure["processing_flava"] = ["FlavaProcessor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flava"] = [
+ "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "FlavaForPreTraining",
+ "FlavaImageCodebook",
+ "FlavaImageModel",
+ "FlavaModel",
+ "FlavaMultimodalModel",
+ "FlavaPreTrainedModel",
+ "FlavaTextModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_flava import (
+ FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ FlavaConfig,
+ FlavaImageCodebookConfig,
+ FlavaImageConfig,
+ FlavaMultimodalConfig,
+ FlavaTextConfig,
+ )
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_flava import FlavaFeatureExtractor
+ from .processing_flava import FlavaProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flava import (
+ FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
+ FlavaForPreTraining,
+ FlavaImageCodebook,
+ FlavaImageModel,
+ FlavaModel,
+ FlavaMultimodalModel,
+ FlavaPreTrainedModel,
+ FlavaTextModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/flava/configuration_flava.py b/src/transformers/models/flava/configuration_flava.py
new file mode 100644
index 00000000000000..c42c90086406b4
--- /dev/null
+++ b/src/transformers/models/flava/configuration_flava.py
@@ -0,0 +1,646 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" FLAVA model configurations"""
+
+import copy
+import os
+from typing import Any, Dict, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/flava-full": "https://huggingface.co/facebook/flava-full/resolve/main/config.json",
+}
+
+
+class FlavaImageConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FlavaImageModel`]. It is used to instantiate an
+ FLAVA model according to the specified arguments, defining the model architecture.
+
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
+ [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ mask_token (`bool`, *optional*, defaults to `True`):
+ Whether to use a mask token or not. Used in MIM (Masked Image Modeling) loss for FLAVA.
+ vocab_size (`int`, *optional*, defaults to 8192):
+ Vocabulary size of the [`FlavaImageCodebook`] used in conjunction with [`FlavaImageModel`] for MIM (Masked
+ Image Modeling) loss for FLAVA.
+
+ Example:
+
+ ```python
+ >>> from transformers import FlavaImageModel, FlavaImageConfig
+
+ >>> # Initializing a FlavaImageModel with style configuration
+ >>> configuration = FlavaImageConfig()
+
+ >>> # Initializing a FlavaImageModel model from the style configuration
+ >>> model = FlavaImageModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "flava_image_model"
+
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ num_hidden_layers: int = 12,
+ num_attention_heads: int = 12,
+ intermediate_size: int = 3072,
+ hidden_act: int = "gelu",
+ hidden_dropout_prob: float = 0.0,
+ attention_probs_dropout_prob: float = 0.0,
+ initializer_range: float = 0.02,
+ layer_norm_eps: float = 1e-12,
+ image_size: int = 224,
+ patch_size: int = 16,
+ num_channels: int = 3,
+ qkv_bias: bool = True,
+ mask_token: bool = True,
+ vocab_size: int = 8192,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.mask_token = mask_token
+ self.vocab_size = vocab_size
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the image config dict if we are loading from FlavaConfig
+ if config_dict.get("model_type") == "flava":
+ config_dict = config_dict["image_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class FlavaTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FlavaTextModel`]. It is used to instantiate an
+ FLAVA model according to the specified arguments, defining the model architecture.
+
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
+ [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`FlavaTextModel`].
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`FlavaTextModel`]. Note that even though
+ text encoder allows `token_type_ids`'s value as 2, for text-only pretraining and fine-tuning, only 1 is
+ used similar to RoBERTa.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048). For VL, max_length passed to model is 77.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+
+ Example:
+
+ ```python
+ >>> from transformers import FlavaTextModel, FlavaTextConfig
+
+ >>> # Initializing a FlavaTextModel with style configuration
+ >>> configuration = FlavaTextConfig()
+
+ >>> # Initializing a FlavaTextConfig from the style configuration
+ >>> model = FlavaTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "flava_text_model"
+
+ def __init__(
+ self,
+ vocab_size: int = 30522,
+ type_vocab_size: int = 2,
+ max_position_embeddings: int = 512,
+ position_embedding_type: str = "absolute",
+ hidden_size: int = 768,
+ num_hidden_layers: int = 12,
+ num_attention_heads: int = 12,
+ intermediate_size: int = 3072,
+ hidden_act: str = "gelu",
+ hidden_dropout_prob: float = 0.0,
+ attention_probs_dropout_prob: float = 0.0,
+ initializer_range: float = 0.02,
+ layer_norm_eps: float = 1e-12,
+ pad_token_id: int = 0,
+ qkv_bias: bool = True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.vocab_size = vocab_size
+ self.type_vocab_size = type_vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.position_embedding_type = position_embedding_type
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.pad_token_id = pad_token_id
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the text config dict if we are loading from FlavaConfig
+ if config_dict.get("model_type") == "flava":
+ config_dict = config_dict["text_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class FlavaMultimodalConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FlavaMultimodalModel`]. It is used to instantiate
+ an FLAVA model according to the specified arguments, defining the model architecture.
+
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
+ [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ use_cls_token (`bool`, *optional*, defaults to `True`):
+ Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import FlavaMultimodalModel, FlavaMultimodalConfig
+
+ >>> # Initializing a FlavaMultimodalModel with style configuration
+ >>> configuration = FlavaMultimodalConfig()
+
+ >>> # Initializing a FlavaMultimodalModel model from the style configuration
+ >>> model = FlavaMultimodalModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "flava_multimodal_model"
+
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ num_hidden_layers: int = 6,
+ num_attention_heads: int = 12,
+ intermediate_size: int = 3072,
+ hidden_act: int = "gelu",
+ hidden_dropout_prob: int = 0.0,
+ attention_probs_dropout_prob: int = 0.0,
+ initializer_range: float = 0.02,
+ layer_norm_eps: float = 1e-12,
+ qkv_bias: bool = True,
+ use_cls_token: bool = True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.use_cls_token = use_cls_token
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the multimodal config dict if we are loading from FlavaConfig
+ if config_dict.get("model_type") == "flava":
+ config_dict = config_dict["multimodal_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class FlavaImageCodebookConfig(PretrainedConfig):
+ model_type = "flava_image_codebook"
+
+ r"""
+ [`FlavaImageCodebookConfig`] is the configuration class to store the configuration of a [`FlavaImageCodebook`]. It
+ is used to instantiate an FLAVA model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
+ [facebook/flava-image-codebook](https://huggingface.co/facebook/flava-image-codebook) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_groups (`int`, defaults to 4):
+ Number of groups to be created. This parameter as of now doesn't affect the model and is used for some
+ internal calculation and estimations.
+ input_channels (`int`, defaults to 3):
+ Number of channels in the image to be passed.
+ num_blocks_per_group (`int`, defaults to 2):
+ Number of conv-based blocks per group.
+ hidden_size (`int`, defaults to 256):
+ Size of hidden dim for the blocks.
+ vocab_size (`int`, defaults to 8192):
+ Size of the output vocabulary for the codebook.
+ freeze (`bool`, defaults to `True`):
+ Whether to freeze the weights of the model.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import FlavaImageCodebook, FlavaImageCodebookConfig
+
+ >>> # Initializing a FlavaImageCodebook with style configuration
+ >>> configuration = FlavaImageCodebookConfig()
+
+ >>> # Initializing a FlavaImageCodebook model from the style configuration
+ >>> model = FlavaImageCodebook(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ def __init__(
+ self,
+ num_groups: int = 4,
+ input_channels: int = 3,
+ num_blocks_per_group: int = 2,
+ hidden_size: int = 256,
+ vocab_size: int = 8192,
+ freeze: int = True,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.num_groups = num_groups
+ self.input_channels = input_channels
+ self.num_blocks_per_group = num_blocks_per_group
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.freeze = freeze
+ self.initializer_range = initializer_range
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the image codebook config dict if we are loading from FlavaConfig
+ if config_dict.get("model_type") == "flava":
+ config_dict = config_dict["image_codebook_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class FlavaConfig(PretrainedConfig):
+ r"""
+ [`FlavaConfig`] is the configuration class to store the configuration of a [`FlavaModel`]. It is used to
+ instantiate FLAVA model according to the specified arguments, defining the text model, image model, image codebook
+ and multimodal model configs. Instantiating a configuration with the defaults will yield a similar configuration to
+ that of the FLAVA [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config_dict (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`FlavaTextConfig`].
+ image_config_dict (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`FlavaImageConfig`].
+ multimodal_config_dict (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ projection_dim (`int`, *optional*, defaults to 512):
+ Dimentionality of text and image projection layers.
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+ The inital value of the *logit_scale* paramter. Default is used as per the original FLAVA/CLIP
+ implementation.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ ce_ignore_index (`int`, *optional*, defaults to -100):
+ Cross entropy index to ignore.
+ mim_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to MIM (Masked Image Modeling) unimodal loss
+ mlm_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to MLM (Masked Language Modeling) unimodal loss
+ global_contrastive_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to global contrastive cross-alignment loss.
+ itm_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to image-text matching multimodal loss.
+ mmm_image_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to MMM loss's image part.
+ mmm_text_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to MMM loss's text part.
+ global_backprop_contrastive (`bool`, *optional*, defaults to `True`):
+ Whether to use global backpropgation through all workers in contrastive loss.
+ skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`):
+ Whether to skip running unmasked multimodal encoder whose outputs are not used by FLAVA losses.
+ return_loss (`bool`, *optional*, defaults to `True`):
+ Whether to return loss or not
+
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import FlavaModel, FlavaForPreTraining, FlavaConfig
+
+ >>> # Initializing a FlavaConfig with style configuration
+ >>> configuration = FlavaConfig()
+
+ >>> # Initializing a FlavaModel and FlavaForPreTraining model from the style configuration
+ >>> model = FlavaModel(configuration)
+ >>> model_pre = FlavaForPreTraining(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ >>> configuration_pre = model_pre.config
+ ```
+ """
+
+ model_type = "flava"
+ is_composition = True
+
+ def __init__(
+ self,
+ image_config_dict: Dict[str, Any] = None,
+ text_config_dict: Dict[str, Any] = None,
+ multimodal_config_dict: Dict[str, Any] = None,
+ image_codebook_config_dict: Dict[str, Any] = None,
+ hidden_size: int = 768,
+ layer_norm_eps: float = 1e-12,
+ projection_dim: int = 768,
+ init_codebook: bool = True,
+ logit_scale_init_value: float = 2.6592,
+ initializer_range: float = 0.02,
+ ce_ignore_index: int = -100,
+ mim_weight: float = 1.0,
+ mlm_weight: float = 1.0,
+ global_contrastive_weight: float = 1.0,
+ itm_weight: float = 1.0,
+ mmm_image_weight: float = 1.0,
+ mmm_text_weight: float = 1.0,
+ global_backprop_contrastive: bool = True,
+ skip_unmasked_multimodal_encoder: bool = True,
+ return_loss: bool = True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ if image_config_dict is None:
+ image_config_dict = {}
+ logger.info("image_config_dict is None. initializing the FlavaImageConfig with default values.")
+
+ if text_config_dict is None:
+ text_config_dict = {}
+ logger.info("text_config_dict is None. Initializing the FlavaTextConfig with default values.")
+
+ if multimodal_config_dict is None:
+ multimodal_config_dict = {}
+ logger.info("multimodal_config_dict is None. initializing the FlavaMultimodalConfig with default values.")
+
+ if image_codebook_config_dict is None:
+ image_codebook_config_dict = {}
+ logger.info(
+ "image_codebook_config_dict is None. initializing the FlavaImageCodebookConfig with default values."
+ )
+
+ self.image_config_dict = image_config_dict
+ self.text_config_dict = text_config_dict
+ self.multimodal_config_dict = multimodal_config_dict
+ self.image_codebook_config_dict = image_codebook_config_dict
+
+ self.image_config = FlavaImageConfig(**self.image_config_dict)
+ self.text_config = FlavaTextConfig(**self.text_config_dict)
+ self.multimodal_config = FlavaMultimodalConfig(**self.multimodal_config_dict)
+ self.image_codebook_config = FlavaImageCodebookConfig(**self.image_codebook_config_dict)
+ self.projection_dim = projection_dim
+ self.init_codebook = init_codebook
+
+ self.hidden_size = hidden_size
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.logit_scale_init_value = logit_scale_init_value
+ self.initializer_factor = 1.0
+ self.ce_ignore_index = ce_ignore_index
+ self.mim_weight = mim_weight
+ self.mlm_weight = mlm_weight
+ self.global_contrastive_weight = global_contrastive_weight
+ self.itm_weight = itm_weight
+ self.mmm_image_weight = mmm_image_weight
+ self.mmm_text_weight = mmm_text_weight
+ self.global_backprop_contrastive = global_backprop_contrastive
+ self.skip_unmasked_multimodal_encoder = skip_unmasked_multimodal_encoder
+ self.return_loss = return_loss
+
+ @classmethod
+ def from_configs(
+ cls,
+ image_config: FlavaImageConfig,
+ text_config: FlavaTextConfig,
+ multimodal_config: FlavaMultimodalConfig,
+ image_codebook_config: FlavaImageCodebookConfig,
+ **kwargs
+ ):
+ r"""
+ Instantiate a [`FlavaConfig`] (or a derived class) from flava text model configuration, flava image model
+ configuration, flava multimodal model and flava codebook model configuration.
+
+ Returns:
+ [`FlavaConfig`]: An instance of a configuration object
+ """
+
+ return cls(
+ image_config_dict=image_config.to_dict(),
+ text_config_dict=text_config.to_dict(),
+ multimodal_config_dict=multimodal_config.to_dict(),
+ image_codebook_config_dict=image_codebook_config.to_dict(),
+ **kwargs,
+ )
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["image_config"] = self.image_config.to_dict()
+ output["text_config"] = self.text_config.to_dict()
+ output["multimodal_config"] = self.multimodal_config.to_dict()
+ output["image_codebook_config"] = self.image_codebook_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+ return output
diff --git a/src/transformers/models/flava/convert_dalle_to_flava_codebook.py b/src/transformers/models/flava/convert_dalle_to_flava_codebook.py
new file mode 100644
index 00000000000000..7b544125114c85
--- /dev/null
+++ b/src/transformers/models/flava/convert_dalle_to_flava_codebook.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import torch
+
+from transformers import FlavaImageCodebook, FlavaImageCodebookConfig
+
+
+def rreplace(s, old, new, occurrence):
+ li = s.rsplit(old, occurrence)
+ return new.join(li)
+
+
+def count_parameters(state_dict):
+ # encoder.embeddings are double copied in original FLAVA
+ return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
+
+
+def upgrade_state_dict(state_dict):
+ upgrade = {}
+
+ group_keys = ["group_1", "group_2", "group_3", "group_4"]
+ for key, value in state_dict.items():
+ for group_key in group_keys:
+ if group_key in key:
+ key = key.replace(f"{group_key}.", f"{group_key}.group.")
+
+ if "res_path" in key:
+ key = key.replace("res_path.", "res_path.path.")
+
+ if key.endswith(".w"):
+ key = rreplace(key, ".w", ".weight", 1)
+ if key.endswith(".b"):
+ key = rreplace(key, ".b", ".bias", 1)
+
+ upgrade[key] = value.float()
+
+ return upgrade
+
+
+@torch.no_grad()
+def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, save_checkpoint=True):
+ """
+ Copy/paste/tweak model's weights to transformers design.
+ """
+ from dall_e import Encoder
+
+ encoder = Encoder()
+ if os.path.exists(checkpoint_path):
+ ckpt = torch.load(checkpoint_path)
+ else:
+ ckpt = torch.hub.load_state_dict_from_url(checkpoint_path)
+
+ if isinstance(ckpt, Encoder):
+ ckpt = ckpt.state_dict()
+ encoder.load_state_dict(ckpt)
+
+ if config_path is not None:
+ config = FlavaImageCodebookConfig.from_pretrained(config_path)
+ else:
+ config = FlavaImageCodebookConfig()
+
+ hf_model = FlavaImageCodebook(config).eval()
+ state_dict = encoder.state_dict()
+
+ hf_state_dict = upgrade_state_dict(state_dict)
+ hf_model.load_state_dict(hf_state_dict)
+ hf_state_dict = hf_model.state_dict()
+ hf_count = count_parameters(hf_state_dict)
+ state_dict_count = count_parameters(state_dict)
+
+ assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
+
+ if save_checkpoint:
+ hf_model.save_pretrained(pytorch_dump_folder_path)
+ else:
+ return hf_state_dict
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+ args = parser.parse_args()
+
+ convert_dalle_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
diff --git a/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py b/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
new file mode 100644
index 00000000000000..95ebb2bfdb2360
--- /dev/null
+++ b/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
@@ -0,0 +1,99 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import torch
+
+from transformers import FlavaConfig, FlavaForPreTraining
+from transformers.models.flava.convert_dalle_to_flava_codebook import convert_dalle_checkpoint
+
+
+def count_parameters(state_dict):
+ # encoder.embeddings are double copied in original FLAVA
+ return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
+
+
+def upgrade_state_dict(state_dict, codebook_state_dict):
+ upgrade = {}
+
+ for key, value in state_dict.items():
+ if "text_encoder.embeddings" in key or "image_encoder.embeddings" in key:
+ continue
+
+ key = key.replace("heads.cmd.mim_head.cls.predictions", "mmm_image_head")
+ key = key.replace("heads.cmd.mlm_head.cls.predictions", "mmm_text_head")
+ key = key.replace("heads.cmd.itm_head.cls", "itm_head")
+ key = key.replace("heads.cmd.itm_head.pooler", "itm_head.pooler")
+ key = key.replace("heads.cmd.clip_head.logit_scale", "flava.logit_scale")
+ key = key.replace("heads.fairseq_mlm.cls.predictions", "mlm_head")
+ key = key.replace("heads.imagenet.mim_head.cls.predictions", "mim_head")
+ key = key.replace("mm_text_projection", "flava.text_to_mm_projection")
+ key = key.replace("mm_image_projection", "flava.image_to_mm_projection")
+ key = key.replace("image_encoder.module", "flava.image_model")
+ key = key.replace("text_encoder.module", "flava.text_model")
+ key = key.replace("mm_encoder.module.encoder.cls_token", "flava.multimodal_model.cls_token")
+ key = key.replace("mm_encoder.module", "flava.multimodal_model")
+ key = key.replace("text_projection", "flava.text_projection")
+ key = key.replace("image_projection", "flava.image_projection")
+
+ upgrade[key] = value.float()
+
+ for key, value in codebook_state_dict.items():
+ upgrade[f"image_codebook.{key}"] = value
+
+ return upgrade
+
+
+@torch.no_grad()
+def convert_flava_checkpoint(checkpoint_path, codebook_path, pytorch_dump_folder_path, config_path=None):
+ """
+ Copy/paste/tweak model's weights to transformers design.
+ """
+ if config_path is not None:
+ config = FlavaConfig.from_pretrained(config_path)
+ else:
+ config = FlavaConfig()
+
+ hf_model = FlavaForPreTraining(config).eval()
+
+ codebook_state_dict = convert_dalle_checkpoint(codebook_path, None, save_checkpoint=False)
+
+ if os.path.exists(checkpoint_path):
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
+ else:
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu")
+
+ hf_state_dict = upgrade_state_dict(state_dict, codebook_state_dict)
+ hf_model.load_state_dict(hf_state_dict)
+ hf_state_dict = hf_model.state_dict()
+ hf_count = count_parameters(hf_state_dict)
+ state_dict_count = count_parameters(state_dict) + count_parameters(codebook_state_dict)
+
+ assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
+
+ hf_model.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
+ parser.add_argument("--codebook_path", default=None, type=str, help="Path to flava codebook checkpoint")
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+ args = parser.parse_args()
+
+ convert_flava_checkpoint(args.checkpoint_path, args.codebook_path, args.pytorch_dump_folder_path, args.config_path)
diff --git a/src/transformers/models/flava/feature_extraction_flava.py b/src/transformers/models/flava/feature_extraction_flava.py
new file mode 100644
index 00000000000000..c3aba8c70b6ce9
--- /dev/null
+++ b/src/transformers/models/flava/feature_extraction_flava.py
@@ -0,0 +1,351 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for FLAVA."""
+
+import math
+import random
+from functools import lru_cache
+from typing import Any, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+# These values are taken from CLIP
+FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
+FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
+FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
+FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
+LOGIT_LAPLACE_EPS: float = 0.1
+
+
+# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
+class FlavaMaskingGenerator:
+ def __init__(
+ self,
+ input_size: Union[int, Tuple[int, int]] = 14,
+ total_mask_patches: int = 75,
+ mask_group_max_patches: Optional[int] = None,
+ mask_group_min_patches: int = 16,
+ mask_group_min_aspect_ratio: Optional[float] = 0.3,
+ mask_group_max_aspect_ratio: float = None,
+ ):
+ if not isinstance(input_size, tuple):
+ input_size = (input_size,) * 2
+ self.height, self.width = input_size
+
+ self.num_patches = self.height * self.width
+ self.total_mask_patches = total_mask_patches
+
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
+
+ mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
+ self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
+
+ def __repr__(self):
+ repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
+ self.height,
+ self.width,
+ self.mask_group_min_patches,
+ self.mask_group_max_patches,
+ self.total_mask_patches,
+ self.log_aspect_ratio[0],
+ self.log_aspect_ratio[1],
+ )
+ return repr_str
+
+ def get_shape(self):
+ return self.height, self.width
+
+ def _mask(self, mask, max_mask_patches):
+ delta = 0
+ for _attempt in range(10):
+ target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ height = int(round(math.sqrt(target_area * aspect_ratio)))
+ width = int(round(math.sqrt(target_area / aspect_ratio)))
+ if width < self.width and height < self.height:
+ top = random.randint(0, self.height - height)
+ left = random.randint(0, self.width - width)
+
+ num_masked = mask[top : top + height, left : left + width].sum()
+ # Overlap
+ if 0 < height * width - num_masked <= max_mask_patches:
+ for i in range(top, top + height):
+ for j in range(left, left + width):
+ if mask[i, j] == 0:
+ mask[i, j] = 1
+ delta += 1
+
+ if delta > 0:
+ break
+ return delta
+
+ def __call__(self):
+ mask = np.zeros(shape=self.get_shape(), dtype=int)
+ mask_count = 0
+ while mask_count < self.total_mask_patches:
+ max_mask_patches = self.total_mask_patches - mask_count
+ max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
+
+ delta = self._mask(mask, max_mask_patches)
+ if delta == 0:
+ break
+ else:
+ mask_count += delta
+
+ return mask
+
+
+class FlavaFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a FLAVA feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int`, *optional*, defaults to 224):
+ Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
+ image is padded with 0's and then center cropped.
+ crop_size (`int`, *optional*, defaults to 224):
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with `image_mean` and `image_std`.
+ image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
+ The sequence of means for each channel, to be used when normalizing images.
+ image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images.
+ input_size_patches (`int`, *optional*, defaults to 14):
+ Number of patches in the image in height and width direction. 14x14 = 196 total patches.
+ total_mask_patches (`int`, *optional*, defaults to 75):
+ Total number of patches that should be masked.
+ mask_group_min_patches (`int`, *optional*, defaults to 16):
+ Minimum number of patches that should be masked.
+ mask_group_max_patches (`int`, *optional*, defaults to None):
+ Maximum number of patches that should be masked.
+ mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
+ Minimum aspect ratio of the mask window.
+ mask_group_max_aspect_ratio (`float`, *optional*, defaults to None):
+ Maximum aspect ratio of the mask window
+ codebook_do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input for codebook to a certain `codebook_size`.
+ codebook_size (`int`, *optional*, defaults to 224):
+ Resize the input for codebook to the given size. Only has an effect if `codebook_do_resize` is set to
+ `True`.
+ codebook_resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input for codebook at the center. If the input size is smaller than
+ `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped.
+ codebook_crop_size (`int`, *optional*, defaults to 224):
+ Desired output size for codebook input when applying center-cropping. Only has an effect if
+ `codebook_do_center_crop` is set to `True`.
+ codebook_do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`.
+ codebook_image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0, 0, 0]`):
+ The sequence of means for each channel, to be used when normalizing images for codebook.
+ codebook_image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images for codebook.
+
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Union[int, Tuple[int, int]] = 224,
+ resample: int = Image.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Union[int, Tuple[int, int]] = 224,
+ do_normalize: bool = True,
+ image_mean: Tuple[float, float, float] = FLAVA_IMAGE_MEAN,
+ image_std: Tuple[float, float, float] = FLAVA_IMAGE_STD,
+ # Mask related params
+ input_size_patches: int = 14,
+ total_mask_patches: int = 75,
+ mask_group_min_patches: int = 16,
+ mask_group_max_patches: Optional[int] = None,
+ mask_group_min_aspect_ratio: float = 0.3,
+ mask_group_max_aspect_ratio: Optional[float] = None,
+ # Codebook related params
+ codebook_do_resize: bool = True,
+ codebook_size: bool = 112,
+ codebook_resample: int = Image.LANCZOS,
+ codebook_do_center_crop: bool = True,
+ codebook_crop_size: int = 112,
+ codebook_do_map_pixels: bool = True,
+ codebook_do_normalize: bool = True,
+ codebook_image_mean: Tuple[float, float, float] = FLAVA_CODEBOOK_MEAN,
+ codebook_image_std: Tuple[float, float, float] = FLAVA_CODEBOOK_STD,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ self.input_size_patches = input_size_patches
+ self.total_mask_patches = total_mask_patches
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_max_patches = mask_group_max_patches
+ self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
+ self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
+
+ self.codebook_do_resize = codebook_do_resize
+ self.codebook_size = codebook_size
+ self.codebook_resample = codebook_resample
+ self.codebook_do_center_crop = codebook_do_center_crop
+ self.codebook_crop_size = codebook_crop_size
+ self.codebook_do_map_pixels = codebook_do_map_pixels
+ self.codebook_do_normalize = codebook_do_normalize
+ self.codebook_image_mean = codebook_image_mean
+ self.codebook_image_std = codebook_image_std
+
+ @property
+ @lru_cache()
+ def masking_generator(self):
+ return FlavaMaskingGenerator(
+ input_size=self.input_size_patches,
+ total_mask_patches=self.total_mask_patches,
+ mask_group_min_patches=self.mask_group_min_patches,
+ mask_group_max_patches=self.mask_group_max_patches,
+ mask_group_min_aspect_ratio=self.mask_group_min_aspect_ratio,
+ mask_group_max_aspect_ratio=self.mask_group_max_aspect_ratio,
+ )
+
+ def map_pixels(self, x):
+ return (1 - 2 * LOGIT_LAPLACE_EPS) * x + LOGIT_LAPLACE_EPS
+
+ def __call__(
+ self,
+ images: Union[
+ Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
+ ],
+ return_image_mask: Optional[bool] = None,
+ return_codebook_pixels: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs: Any
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
+ PIL images.
+
+
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+
+ return_image_mask (`bool`, *optional*, defaults to None):
+ If True, the processor will return `bool_masked_pos` suggesting masks for image's patch version.
+
+ return_codebook_pixels (`bool`, *optional*, defaults to None):
+ If True, the processor will return `codebook_pixel_values` providing image pixels to be used with the
+ default FLAVA codebook. Used in pretraining by Masked Image Modeling (MIM) loss.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model.
+ """
+ # Input type checking for clearer error
+ if isinstance(images, (list, tuple)) and len(images) != 0:
+ self._ensure_format_supported(images[0])
+ else:
+ self._ensure_format_supported(images)
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ images_for_codebook = images
+
+ # transformations (resizing + center cropping + normalization)
+ if self.do_resize and self.size is not None and self.resample is not None:
+ images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
+ if self.do_center_crop and self.crop_size is not None:
+ images = [self.center_crop(image, self.crop_size) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+ # return as BatchFeature
+ data = {"pixel_values": images}
+
+ if return_codebook_pixels:
+ images = images_for_codebook
+ if self.codebook_do_resize and self.codebook_size is not None and self.codebook_resample is not None:
+ images = [
+ self.resize(image=image, size=self.codebook_size, resample=self.codebook_resample)
+ for image in images
+ ]
+ if self.codebook_do_center_crop and self.codebook_crop_size is not None:
+ images = [self.center_crop(image, self.codebook_crop_size) for image in images]
+ if self.codebook_do_normalize:
+ images = [
+ self.normalize(image=image, mean=self.codebook_image_mean, std=self.codebook_image_std)
+ for image in images
+ ]
+ if self.codebook_do_map_pixels:
+ images = [self.map_pixels(image) for image in images]
+
+ data["codebook_pixel_values"] = images
+
+ if return_image_mask:
+ masks = [self.masking_generator() for _ in images]
+ data["bool_masked_pos"] = masks
+
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py
new file mode 100644
index 00000000000000..c0841a0e277230
--- /dev/null
+++ b/src/transformers/models/flava/modeling_flava.py
@@ -0,0 +1,2099 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch FLAVA model."""
+
+import collections
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from packaging import version
+from torch import nn
+
+from transformers.utils.doc import add_code_sample_docstrings
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_flava import (
+ FlavaConfig,
+ FlavaImageCodebookConfig,
+ FlavaImageConfig,
+ FlavaMultimodalConfig,
+ FlavaTextConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/flava-full"
+
+# Codebook docstring
+_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook"
+_FEAT_EXTRACTOR_FOR_DOC = "FlavaFeatureExtractor"
+_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FlavaImageConfig"
+_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FlavaTextConfig"
+_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FlavaMultimodalConfig"
+_TOKENIZER_FOR_DOC = "BertTokenizer"
+_EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768]
+
+FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/flava-full",
+ # See all flava models at https://huggingface.co/models?filter=flava
+]
+FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/flava-image-codebook"]
+LOGIT_SCALE_CLAMP_MIN = 0
+LOGIT_SCALE_CLAMP_MAX = 4.6052
+
+FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig]
+
+
+@dataclass
+class FlavaModelOutput(ModelOutput):
+ """
+ Output from FlavaModel containing embeddings and outputs from individual encoders.
+
+ Note that `image_embeddings` and `text_embeddigns` returned are similar to pooled output returned from a
+ transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
+ `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
+
+ Args:
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+ The image embeddings which are basically the pooled output of [`FlavaImageModel`].
+ image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+ The output of the [`FlavaImageModel`].
+ text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
+ The text embeddings which are basically the pooled output of [`FlavaTextModel`].
+ text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
+ The output of the [`FlavaTextModel`].
+ multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
+ The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
+ multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
+ The output of the [`FlavaMultimodalModel`].
+ """
+
+ image_embeddings: Optional[torch.FloatTensor] = None
+ image_output: Optional[BaseModelOutputWithPooling] = None
+ text_embeddings: Optional[torch.FloatTensor] = None
+ text_output: Optional[BaseModelOutputWithPooling] = None
+ multimodal_embeddings: Optional[torch.FloatTensor] = None
+ multimodal_output: Optional[BaseModelOutputWithPooling] = None
+
+ def to_tuple(self) -> Tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+@dataclass
+class FlavaLosses(ModelOutput):
+ """Class representing pretraining losses from FLAVA model
+
+ Args:
+ mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.:
+ Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
+ mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.:
+ Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
+ itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.:
+ Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
+ masked pairs in FLAVA.
+ global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.:
+ Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
+ data. This is calculated on unmasked images and texts.
+ mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.:
+ Masked Multimodal Modeling loss's image component calculated on paired image-text data.
+ mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.:
+ Masked Multimodal Modeling loss's text component calculated on paired image-text data.
+ """
+
+ mim: Optional[torch.FloatTensor] = None
+ mlm: Optional[torch.FloatTensor] = None
+ itm: Optional[torch.FloatTensor] = None
+ global_contrastive: Optional[torch.FloatTensor] = None
+ mmm_image: Optional[torch.FloatTensor] = None
+ mmm_text: Optional[torch.FloatTensor] = None
+
+ def all_none(self) -> bool:
+ all_none = True
+ for v in self.values():
+ if v is not None:
+ all_none = False
+ break
+ return all_none
+
+
+@dataclass
+class FlavaForPreTrainingOutput(ModelOutput):
+ """
+ Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders.
+
+ Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
+ transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
+ `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
+
+ Args:
+ loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True):
+ Total loss calculated for this model.
+ loss_info (`FlavaLosses`):
+ Detailed info for FLAVA Pretraining losses. Check `FlavaLosses` class description for the information on
+ the keys.
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+ The image embeddings which are basically the pooled output of [`FlavaImageModel`].
+ image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+ The output of the [`FlavaImageModel`].
+ text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
+ The text embeddings which are basically the pooled output of [`FlavaTextModel`].
+ text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
+ The output of the [`FlavaTextModel`].
+ multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
+ The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
+ multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
+ The output of the [`FlavaMultimodalModel`].
+
+ image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+ The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos`
+ to create masked images.
+ image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+ The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images.
+ text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
+ The text embeddings which are basically the pooled output of [`FlavaTextModel`].
+ text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
+ The output of the [`FlavaTextModel`].
+ multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present):
+ The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
+ multimodal_masked_output (`BaseModelOutputWithPooling`, returned when `input_ids_masked` and `pixel_values` are present):
+ The output of the [`FlavaMultimodalModel`].
+
+ mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)` , *optional*, returned when `pixel_values` are present and `input_ids_masked` are not):
+ The logits for MIM unimodal loss. Uses `book_masked_pos` to get masked patches. The flattened output is
+ returned when `bool_masked_pos` has some of the patches masked.
+ mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not):
+ The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of
+ the tokens masked.
+ itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
+ The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA.
+ mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape`(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
+ The logits for MMM image multimodal loss. Uses `book_masked_pos` to get masked patches. The flattened
+ output is returned when `bool_masked_pos` has some of the patches masked.
+ mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(`(total_masked_seq_length, text_vocab_size)`), *optional*, returned when `pixel_values` and `input_ids_masked` are present):
+ The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has
+ some of the tokens masked.
+ contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
+ `image_projection` and `text_projection` layers respectively. This represents the image-text similarity
+ scores. This is calculated on unmasked images and texts.
+ contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
+ `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
+ texts.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ loss_info: FlavaLosses = None
+ image_embeddings: Optional[torch.FloatTensor] = None
+ image_output: Optional[BaseModelOutputWithPooling] = None
+ text_embeddings: Optional[torch.FloatTensor] = None
+ text_output: Optional[BaseModelOutputWithPooling] = None
+ multimodal_embeddings: Optional[torch.FloatTensor] = None
+ multimodal_output: Optional[BaseModelOutputWithPooling] = None
+ image_masked_embeddings: Optional[torch.FloatTensor] = None
+ image_masked_output: Optional[BaseModelOutputWithPooling] = None
+ text_masked_embeddings: Optional[torch.FloatTensor] = None
+ text_masked_output: Optional[BaseModelOutputWithPooling] = None
+ multimodal_masked_embeddings: Optional[torch.FloatTensor] = None
+ multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None
+ mim_logits: Optional[torch.FloatTensor] = None
+ mlm_logits: Optional[torch.FloatTensor] = None
+ itm_logits: Optional[torch.FloatTensor] = None
+ contrastive_logits_per_image: Optional[torch.FloatTensor] = None
+ contrastive_logits_per_text: Optional[torch.FloatTensor] = None
+ mmm_image_logits: Optional[torch.FloatTensor] = None
+ mmm_text_logits: Optional[torch.FloatTensor] = None
+
+ def to_tuple(self) -> Tuple[Any]:
+ transformer_outputs = [
+ "text_output",
+ "image_output",
+ "multimodal_output",
+ "text_masked_output",
+ "image_masked_output",
+ "multimodal_masked_output",
+ ]
+ return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
+class FlavaImageEmbeddings(nn.Module):
+ """
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None:
+ super().__init__()
+
+ use_mask_token = use_mask_token or config.mask_token
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+ self.patch_embeddings = PatchEmbeddings(
+ image_size=config.image_size,
+ patch_size=config.patch_size,
+ num_channels=config.num_channels,
+ embed_dim=config.hidden_size,
+ )
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
+ resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/image_transformer.py#L174
+ """
+
+ npatch = embeddings.shape[1] - 1
+ num_pos = self.position_embeddings.shape[1] - 1
+ if npatch == num_pos and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ num_h_patches = height // self.config.patch_size
+ num_w_patches = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(num_pos)), int(math.sqrt(num_pos)), dim).permute(0, 3, 1, 2),
+ scale_factor=(num_h_patches / math.sqrt(num_pos), num_w_patches / math.sqrt(num_pos)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]:
+ raise ValueError(
+ f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
+ f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})"
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ batch_size, seq_len, _ = embeddings.size()
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+ # B X H X W = B X HW
+ if bool_masked_pos.dim() == 3:
+ bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
+class PatchEmbeddings(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ image_size: int = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ num_channels: int = 3,
+ embed_dim: int = 768,
+ ):
+ super().__init__()
+ if not isinstance(image_size, collections.abc.Iterable):
+ image_size = (image_size, image_size)
+ if not isinstance(patch_size, collections.abc.Iterable):
+ patch_size = (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return x
+
+
+class FlavaTextEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
+ persistent=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ):
+ input_shape = input_ids.size()
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+ # issue #5664
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class FlavaSelfAttention(nn.Module):
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class FlavaSelfOutput(nn.Module):
+ """
+ The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
+ models), due to the layernorm applied before each block.
+ """
+
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class FlavaAttention(nn.Module):
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
+ super().__init__()
+ self.attention = FlavaSelfAttention(config)
+ self.output = FlavaSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(
+ hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions
+ )
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class FlavaIntermediate(nn.Module):
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate.forward
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+class FlavaOutput(nn.Module):
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput.forward
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+class FlavaLayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: FlavaPossibleConfigs) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = FlavaAttention(config)
+ self.intermediate = FlavaIntermediate(config)
+ self.output = FlavaOutput(config)
+
+ # TODO: Check fp32 layer norm possiblity
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in ViT, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+class FlavaEncoder(nn.Module):
+ def __init__(self, config: FlavaConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
+ )
+
+
+class FlavaPooler(nn.Module):
+ def __init__(self, config: FlavaPossibleConfigs):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+FLAVA_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`{config}`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+FLAVA_INPUTS_DOCSTRING_COMMON = r"""
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`FlavaFeatureExtractor`]. See
+ [`FlavaFeatureExtractor.__call__`] for details.
+
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ interpolate_pos_encoding (`bool`, *optional*):
+ Whether to interpolate the pre-trained position encodings.
+"""
+
+FLAVA_IMAGE_INPUTS_DOCSTRING = FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
+
+FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`BertTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+ [What are token type IDs?](../glossary#token-type-ids)
+"""
+
+FLAVA_TEXT_INPUTS_DOCSTRING = FLAVA_TEXT_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
+
+FLAVA_MULTIMODAL_INPUTS_DOCSTRING = (
+ r"""
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
+ The concatenated hidden states of unimodal encoders.
+"""
+ + FLAVA_INPUTS_DOCSTRING_COMMON
+)
+
+FLAVA_MODEL_INPUTS_DOCSTRING_BASE = r"""
+ Args:
+ skip_multimodal_encoder (*bool*, *optional*):
+ Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
+"""
+
+FLAVA_MODEL_INPUTS_DOCSTRING = (
+ FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
+ + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
+ + FLAVA_INPUTS_DOCSTRING_COMMON
+ + FLAVA_MODEL_INPUTS_DOCSTRING_BASE
+)
+
+
+FLAVA_PRETRAINING_INPUTS_DOCSTRING = (
+ r"""
+ Args:
+ input_ids_masked (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary. These ones are the masked version of the original task
+ to be used with MLM. Indices can be obtained using [`BertTokenizer`] along with
+ [`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
+
+"""
+ + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
+ + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
+ + r"""
+ image_attention_mask (`torch.FloatTensor` of shape `({1})`, *optional*):
+ Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
+ in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+
+ skip_unmasked_multimodal_encoder (*bool*, *optional*):
+ Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked
+ multimodal embeddings or outputs as of now.
+
+ mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
+ Labels for computing the left-to-right language and multimodal masked modeling loss (next word prediction).
+ Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids` docstring). Tokens with
+ indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0,
+ ..., text_config.vocab_size - 1]`.
+
+ mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
+ Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
+ image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
+ computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are
+ generated automatically using the image codebook assigned to the model. By default, it uses
+ [`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate mim_labels.
+
+ itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+ Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
+ The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
+
+ return_loss (`bool`, *optional*, default to None):
+ Whether to return calculated loss or not.
+"""
+ + FLAVA_INPUTS_DOCSTRING_COMMON
+)
+
+FLAVA_PRETRAINING_START_DOCSTRING_EXTRA = r"""
+ Parameters:
+ image_codebook ([`nn.Module`]): If passed, the image codebook will be set to this. Otherwise. it will
+ be initialized using the image_codebook_config defined in the config first as the first parameter.
+"""
+
+
+class FlavaPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = FlavaConfig
+ base_model_prefix = "flava"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module: FlavaEncoder, value: bool = False) -> None:
+ if isinstance(module, FlavaEncoder):
+ module.gradient_checkpointing = value
+
+
+@add_start_docstrings(
+ "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAVA_START_DOCSTRING.format(config="FlavaImageConfig"),
+)
+class FlavaImageModel(FlavaPreTrainedModel):
+ config_class = FlavaImageConfig
+ # This override allows us to load FlavaImageModel from FlavaModel/FlavaForPreTraining checkpoints.
+ base_model_prefix = "flava.image_model"
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True):
+ super().__init__(config)
+
+ self.config = config
+
+ self.embeddings = FlavaImageEmbeddings(config)
+ self.encoder = FlavaEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = FlavaPooler(config) if add_pooling_layer else None
+
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.embeddings.patch_embeddings
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.embeddings.patch_embeddings = value
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_IMAGE_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAVA_START_DOCSTRING.format(config="FlavaTextConfig"),
+)
+class FlavaTextModel(FlavaPreTrainedModel):
+ config_class = FlavaTextConfig
+ # This override allows us to load FlavaTextModel from FlavaModel/FlavaForPreTraining checkpoints.
+ base_model_prefix = "flava.text_model"
+
+ def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = FlavaTextEmbeddings(config)
+ self.encoder = FlavaEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = FlavaPooler(config) if add_pooling_layer else None
+
+ self.post_init()
+
+ def get_input_embeddings(self) -> PatchEmbeddings:
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_CLASS_FOR_TEXT_MODEL_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=input_ids.device)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, input_shape, input_ids.device
+ )
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAVA_START_DOCSTRING.format(config="FlavaMultimodalConfig"),
+)
+class FlavaMultimodalModel(FlavaPreTrainedModel):
+ config_class = FlavaMultimodalConfig
+ # This override allows us to load FlavaMultimodalModel from FlavaModel/FlavaForPreTraining checkpoints.
+ base_model_prefix = "flava.multimodal_model"
+ main_input_name = "hidden_states"
+
+ def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+ self.use_cls_token = self.config.use_cls_token
+ if self.use_cls_token:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+
+ self.encoder = FlavaEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = FlavaPooler(config) if add_pooling_layer else None
+
+ self.post_init()
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(
+ FLAVA_MULTIMODAL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
+ )
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC,
+ )
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, seq_length, _ = hidden_states.size()
+
+ if self.use_cls_token:
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
+ seq_length += 1
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, (batch_size, seq_length), hidden_states.device
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAVA_START_DOCSTRING.format(config="FlavaConfig"),
+)
+class FlavaModel(FlavaPreTrainedModel):
+ config_class = FlavaConfig
+
+ def __init__(self, config: FlavaConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, FlavaTextConfig):
+ raise ValueError(
+ "config.text_config is expected to be of type FlavaTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.image_config, FlavaImageConfig):
+ raise ValueError(
+ "config.image_config is expected to be of type FlavaImageConfig but is of type"
+ f" {type(config.image_config)}."
+ )
+
+ if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
+ raise ValueError(
+ "config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
+ + f"is of type {type(config.multimodal_config)}."
+ )
+
+ text_config = config.text_config
+ image_config = config.image_config
+ multimodal_config = config.multimodal_config
+
+ self.projection_dim = config.projection_dim
+ self.text_hidden_size = text_config.hidden_size
+ self.image_hidden_size = image_config.hidden_size
+ self.mm_hidden_size = multimodal_config.hidden_size
+
+ self.text_model = FlavaTextModel(text_config)
+ self.image_model = FlavaImageModel(image_config)
+ self.multimodal_model = FlavaMultimodalModel(multimodal_config)
+
+ self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
+ self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
+
+ self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
+ self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`FlavaTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import FlavaProcessor, FlavaModel
+
+ >>> model = FlavaModel.from_pretrained("{0}")
+ >>> processor = FlavaProcessor.from_pretrained("{0}")
+
+ >>> inputs = processor(
+ ... text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
+ ... )
+ >>> text_features = model.get_text_features(**inputs)
+ ```""".format(
+ _CHECKPOINT_FOR_DOC
+ )
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = text_outputs[0] # last_hidden_state
+ text_features = self.text_projection(pooled_output)
+
+ return text_features
+
+ @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`FlavaImageModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FlavaProcessor, FlavaModel
+
+ >>> model = FlavaModel.from_pretrained("{0}")
+ >>> processor = FlavaProcessor.from_pretrained("{0}")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> image_features = model.get_image_features(**inputs)
+ ```""".format(
+ _CHECKPOINT_FOR_DOC
+ )
+ image_outputs = self.image_model(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ pooled_output = image_outputs[0] # last_hidden_state
+ image_features = self.image_projection(pooled_output)
+
+ return image_features
+
+ @add_start_docstrings_to_model_forward(
+ FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
+ )
+ @replace_return_docstrings(output_type=FlavaModelOutput, config_class=FlavaConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ skip_multimodal_encoder: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: bool = True,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, FlavaOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FlavaProcessor, FlavaModel
+
+ >>> model = FlavaModel.from_pretrained("facebook/flava-full")
+ >>> processor = FlavaProcessor.from_pretrained("facebook/flava-full")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
+
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.contrastive_logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+ if not output_hidden_states:
+ raise ValueError("FLAVA model requires hidden states to work. Please set `output_hidden_states=True`")
+ image_embeddings = None
+ image_states = None
+ image_mm_projection = None
+ image_output = None
+ if pixel_values is not None:
+ image_output = self.image_model(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ image_embeddings, image_states = image_output[0], image_output[2]
+ # Note that these states don't use final layernorm in the transformer model
+ image_mm_projection = self.image_to_mm_projection(image_states[-1])
+
+ text_embeddings = None
+ text_states = None
+ text_mm_projection = None
+ text_output = None
+ if input_ids is not None:
+ text_output = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_embeddings, text_states = text_output[0], text_output[2]
+ # Note that these states don't use final layernorm in the transformer model
+ text_mm_projection = self.text_to_mm_projection(text_states[-1])
+
+ multimodal_embeddings = None
+ multimodal_output = None
+ if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder:
+ multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
+ multimodal_output = self.multimodal_model(multimodal_input, return_dict=return_dict)
+ multimodal_embeddings = multimodal_output[0]
+
+ if not return_dict:
+ return (
+ image_embeddings,
+ image_output,
+ text_embeddings,
+ text_output,
+ multimodal_embeddings,
+ multimodal_output,
+ )
+
+ return FlavaModelOutput(
+ image_embeddings=image_embeddings,
+ image_output=image_output,
+ text_embeddings=text_embeddings,
+ text_output=text_output,
+ multimodal_embeddings=multimodal_embeddings,
+ multimodal_output=multimodal_output,
+ )
+
+
+class FlavaImageCodebookResPath(nn.Module):
+ def __init__(self, in_size: int, out_size: int, **kwargs):
+ super().__init__()
+ hid_size = out_size // 4
+
+ path = OrderedDict()
+ path["relu_1"] = nn.ReLU()
+ path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1)
+ path["relu_2"] = nn.ReLU()
+ path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
+ path["relu_3"] = nn.ReLU()
+ path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
+ path["relu_4"] = nn.ReLU()
+ path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0)
+
+ self.path = nn.Sequential(path)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.path(x)
+
+
+class FlavaImageCodebookBlock(nn.Module):
+ def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
+ super().__init__()
+
+ self.post_gain = 1 / (num_layers**2)
+
+ if in_size != out_size:
+ self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
+ else:
+ self.id_path = nn.Identity()
+
+ self.res_path = FlavaImageCodebookResPath(in_size, out_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.id_path(x) + self.post_gain * self.res_path(x)
+
+
+class FlavaImageCodebookLayerGroup(nn.Module):
+ def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True):
+ super().__init__()
+ blocks = OrderedDict()
+ for i in range(num_blocks):
+ if i == 0:
+ blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers)
+ else:
+ blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers)
+
+ if use_pool:
+ blocks["pool"] = nn.MaxPool2d(kernel_size=2)
+
+ self.group = nn.Sequential(blocks)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.group(x)
+
+
+# Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42
+@add_start_docstrings(
+ """
+ The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used
+ to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
+ `get_codebook_indices` to get image tokens for an image.
+ """,
+ FLAVA_START_DOCSTRING.format(config="FlavaImageCodebookConfig"),
+)
+class FlavaImageCodebook(FlavaPreTrainedModel):
+ base_model_prefix = ""
+ config_class = FlavaImageCodebookConfig
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = False
+
+ def __init__(
+ self,
+ config: FlavaImageCodebookConfig,
+ **kwargs: Any,
+ ):
+ super().__init__(config)
+
+ self.config = config
+ self.num_groups = config.num_groups
+ self.input_channels = config.input_channels
+ self.num_blocks_per_group = config.num_blocks_per_group
+ self.hidden_size = config.hidden_size
+ self.vocab_size = config.vocab_size
+
+ num_layers = self.num_groups * self.num_blocks_per_group
+
+ output_blocks = OrderedDict()
+ output_blocks["relu"] = nn.ReLU()
+ output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0)
+
+ blocks = OrderedDict()
+ blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3)
+ blocks["group_1"] = FlavaImageCodebookLayerGroup(
+ self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size
+ )
+ blocks["group_2"] = FlavaImageCodebookLayerGroup(
+ self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size
+ )
+ blocks["group_3"] = FlavaImageCodebookLayerGroup(
+ self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size
+ )
+ blocks["group_4"] = FlavaImageCodebookLayerGroup(
+ self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False
+ )
+ blocks["output"] = nn.Sequential(output_blocks)
+
+ self.blocks = nn.Sequential(blocks)
+
+ self.post_init()
+
+ if self.config.freeze:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Codebook pixel values can be obtained using [`FlavaFeatureExtractor`] by passing
+ `return_codebook_pixels=True`. See [`FlavaFeatureExtractor.__call__`] for details.
+
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FlavaFeatureExtractor, FlavaImageCodebook
+
+ >>> model = FlavaImageCodebook.from_pretrained("{0}")
+ >>> feature_extractor = FlavaFeatureExtractor.from_pretrained("{0}")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = feature_extractor([image], return_codebook_pixels=True, return_tensors="pt")
+ >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
+
+ >>> outputs = model.get_codebook_indices(**inputs)
+ ```
+ """.format(
+ _CHECKPOINT_FOR_CODEBOOK_DOC
+ )
+ z_logits = self.blocks(pixel_values)
+ return torch.argmax(z_logits, axis=1)
+
+ def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ z_logits = self.blocks(pixel_values)
+ return nn.Softmax(dim=1)(z_logits)
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Codebook pixel values can be obtained using [`FlavaFeatureExtractor`] by passing
+ `return_codebook_pixels=True`. See [`FlavaFeatureExtractor.__call__`] for details.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FlavaFeatureExtractor, FlavaImageCodebook
+
+ >>> model = FlavaImageCodebook.from_pretrained("{0}")
+ >>> feature_extractor = FlavaFeatureExtractor.from_pretrained("{0}")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = feature_extractor([image], return_codebook_pixels=True, return_tensors="pt")
+ >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
+
+ >>> outputs = model(**inputs)
+ >>> print(outputs.shape)
+ (1, 196)
+ ```
+ """.format(
+ _CHECKPOINT_FOR_CODEBOOK_DOC
+ )
+ if len(pixel_values.shape) != 4:
+ raise ValueError(f"input shape {pixel_values.shape} is not 4d")
+ if pixel_values.shape[1] != self.input_channels:
+ raise ValueError(f"input has {pixel_values.shape[1]} channels but model built for {self.input_channels}")
+ return self.blocks(pixel_values)
+
+
+class FlavaPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class FlavaMaskedPredictionHead(nn.Module):
+ def __init__(self, config, weight=None):
+ super().__init__()
+ self.config = config
+ self.transform = FlavaPredictionHeadTransform(config)
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ if weight is not None:
+ self.decoder.weight = weight
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, x):
+ x = self.transform(x)
+ x = self.decoder(x)
+ return x
+
+
+class FlavaITMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pooler = FlavaPooler(config)
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, x):
+ x = self.pooler(x)
+ x = self.seq_relationship(x)
+ return x
+
+
+class FlavaGlobalContrastiveHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.global_backprop_contrastive = config.global_backprop_contrastive
+
+ def forward(self, image_embeddings, text_embeddings, logit_scale):
+ temperature = torch.exp(logit_scale)
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+ labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
+ image_embeddings_all = [image_embeddings]
+ text_embeddings_all = [text_embeddings]
+ else:
+ local_batch_size = image_embeddings.size(0)
+ world_size = torch.distributed.get_world_size()
+
+ if self.global_backprop_contrastive:
+ image_embeddings_all = torch.distributed.nn.functional.all_gather_with_backprop(image_embeddings)
+ text_embeddings_all = torch.distributed.nn.functional.all_gather_with_backprop(text_embeddings)
+ else:
+ image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
+ text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
+ torch.distributed.all_gather(image_embeddings_all, image_embeddings)
+ torch.distributed.all_gather(text_embeddings_all, text_embeddings)
+
+ labels = local_batch_size * torch.distributed.get_rank() + torch.arange(
+ local_batch_size, device=image_embeddings.device
+ )
+
+ image_embeddings_all = torch.cat(image_embeddings_all)
+ text_embeddings_all = torch.cat(text_embeddings_all)
+
+ logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature
+ logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature
+
+ return logits_per_image, logits_per_text, labels
+
+
+@add_start_docstrings(
+ """
+ The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
+ """,
+ FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA,
+)
+class FlavaForPreTraining(FlavaPreTrainedModel):
+ def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
+ super().__init__(config)
+ self.flava = FlavaModel(config)
+
+ self.image_codebook = image_codebook
+ if self.image_codebook is None and config.init_codebook:
+ self.image_codebook = FlavaImageCodebook(config.image_codebook_config)
+
+ # Levarage text and image encoder configs to create the masked
+ # head since it has the right vocab
+ self.mim_head = FlavaMaskedPredictionHead(config.image_config)
+ self.mlm_head = FlavaMaskedPredictionHead(config.text_config)
+ self.itm_head = FlavaITMHead(config)
+ self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config)
+ self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config)
+ self.global_contrastive_head = FlavaGlobalContrastiveHead(config)
+
+ self.image_vocab_size = config.image_config.vocab_size
+ self.text_vocab_size = config.text_config.vocab_size
+ self.mlm_weight = config.mlm_weight
+ self.mim_weight = config.mim_weight
+ self.global_contrastive_weight = config.global_contrastive_weight
+ self.ce_ignore_index = config.ce_ignore_index
+ self.itm_weight = config.itm_weight
+ self.mmm_image_weight = config.mmm_image_weight
+ self.mmm_text_weight = config.mmm_text_weight
+ self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder
+
+ self.post_init()
+
+ def _resize_to_2d(self, x: torch.Tensor):
+ if x.dim() > 2:
+ x = x.view(x.size(0), -1)
+ return x
+
+ @add_start_docstrings_to_model_forward(
+ FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches")
+ )
+ @replace_return_docstrings(output_type=FlavaForPreTrainingOutput, config_class=FlavaConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ input_ids_masked: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ codebook_pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ skip_unmasked_multimodal_encoder: bool = None,
+ mlm_labels: Optional[torch.Tensor] = None,
+ mim_labels: Optional[torch.Tensor] = None,
+ itm_labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: bool = True,
+ return_dict: Optional[bool] = None,
+ return_loss: Optional[bool] = None,
+ ):
+ """
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FlavaForPreTraining, FlavaProcessor
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
+ >>> processor = FlavaProcessor.from_pretrained("facebook/flava-full")
+
+ >>> text = ["a photo of a cat"]
+
+ >>> inputs = processor(
+ ... images=[image],
+ ... text=text,
+ ... return_masks=True,
+ ... return_codebook_pixels=True,
+ ... padding=True,
+ ... max_length=77,
+ ... return_tensors="pt",
+ ... )
+
+
+ >>> output = model(**inputs)
+ ```
+
+ Return:
+
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ return_loss = return_loss if return_loss is not None else self.config.return_loss
+
+ skip_unmasked_multimodal_encoder = (
+ skip_unmasked_multimodal_encoder
+ if skip_unmasked_multimodal_encoder is not None
+ else self.skip_unmasked_multimodal_encoder
+ )
+
+ if input_ids_masked is None and input_ids is not None:
+ logger.warning(
+ "`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to"
+ " `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if"
+ " you are doing inference on unmasked text..."
+ )
+ input_ids_masked = input_ids
+
+ flava_output = self.flava(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ image_attention_mask=image_attention_mask,
+ # Don't need unmasked multimodal embedding for anything so skip it
+ # NOTE: ITM uses masked version
+ skip_multimodal_encoder=skip_unmasked_multimodal_encoder,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ # Pass true to have deterministic outputs
+ return_dict=True,
+ )
+
+ flava_masked_output = self.flava(
+ input_ids=input_ids_masked,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ image_attention_mask=image_attention_mask,
+ bool_masked_pos=bool_masked_pos,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ pos_mask = None
+
+ image_embeddings = flava_output.image_embeddings
+ text_embeddings = flava_output.text_embeddings
+ image_masked_embeddings = flava_masked_output.image_embeddings
+ text_masked_embeddings = flava_masked_output.text_embeddings
+ multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings
+
+ total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None
+ mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None
+ itm_logits = logits_per_image = logits_per_text = None
+
+ # Calculate mim_labels if necessary from the image_codebook
+ if image_masked_embeddings is not None or multimodal_masked_embeddings is not None:
+ if mim_labels is None and return_loss:
+ if self.image_codebook is None:
+ raise RuntimeError(
+ "`return_loss` is set to True but the image codebook is not initialized and no `mim_labels` "
+ " have been passed. Reinstantiate the model with `init_codebook` set to True or "
+ "pass in your custom `mim_labels`"
+ )
+ if codebook_pixel_values is None:
+ raise ValueError(
+ "`codebook_pixel_value` are required to generate `mim_labels` if loss is expected. "
+ "Call `FlavaProcessor` with `return_codebook_pixels` set to True"
+ )
+ mim_labels = self.image_codebook.get_codebook_indices(codebook_pixel_values)
+ # Unimodal MIM Loss
+ # If multimodal embeddings are present, we will calculate MMM loss
+ if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None:
+ sequence_for_image = image_masked_embeddings
+
+ if mim_labels is not None:
+ mim_labels = self._resize_to_2d(mim_labels)
+ bool_masked_pos = self._resize_to_2d(bool_masked_pos)
+ mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
+
+ sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :]
+ masked_tokens = mim_labels.ne(self.ce_ignore_index)
+ mim_labels_filtered = mim_labels[masked_tokens]
+ sequence_for_image = sequence_for_image[masked_tokens, :]
+ mim_logits = self.mim_head(sequence_for_image)
+ if return_loss:
+ mim_loss = nn.functional.cross_entropy(
+ mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
+ )
+ mim_loss *= self.mim_weight
+ else:
+ mim_logits = self.mim_head(sequence_for_image)
+
+ # Unimodal MLM Loss
+ if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None:
+ sequence_for_text = text_masked_embeddings
+ if mlm_labels is not None:
+ mlm_labels = self._resize_to_2d(mlm_labels)
+ sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :]
+ masked_tokens = mlm_labels.ne(self.ce_ignore_index)
+ mlm_labels_filtered = mlm_labels[masked_tokens]
+ sequence_for_text = sequence_for_text[masked_tokens, :]
+ mlm_logits = self.mlm_head(sequence_for_text)
+ if return_loss:
+ mlm_loss = nn.functional.cross_entropy(
+ mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
+ )
+ mlm_loss *= self.mlm_weight
+ else:
+ mlm_logits = self.mlm_head(sequence_for_text)
+
+ # ITM Loss
+ if self.itm_weight > 0 and multimodal_masked_embeddings is not None:
+ itm_logits = self.itm_head(multimodal_masked_embeddings)
+
+ if itm_labels is not None:
+ pos_pairs = itm_labels.ne(0)
+ pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))
+ if return_loss:
+ itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels)
+ itm_loss *= self.itm_weight
+
+ if multimodal_masked_embeddings is not None:
+ multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask]
+
+ if mlm_labels is not None:
+ mlm_labels = mlm_labels[pos_mask]
+
+ if mim_labels is not None:
+ mim_labels = mim_labels[pos_mask]
+
+ # MMM Image Loss
+ if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
+ sequence_for_image = multimodal_masked_embeddings
+ end_index = image_masked_embeddings.size(1) - 1
+ sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
+
+ if pos_mask is not None:
+ sequence_for_image = sequence_for_image[pos_mask]
+ if mim_labels is not None:
+ mim_labels = self._resize_to_2d(mim_labels)
+ bool_masked_pos = self._resize_to_2d(bool_masked_pos)
+ mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
+
+ masked_tokens = mim_labels.ne(self.ce_ignore_index)
+ mim_labels_filtered = mim_labels[masked_tokens]
+ sequence_for_image = sequence_for_image[masked_tokens, :]
+ mmm_image_logits = self.mmm_image_head(sequence_for_image)
+ if return_loss:
+ mmm_image_loss = nn.functional.cross_entropy(
+ mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
+ )
+ mmm_image_loss *= self.mmm_image_weight
+ else:
+ mmm_image_logits = self.mmm_image_head(sequence_for_image)
+
+ # MMM Text Loss
+ if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
+ sequence_for_text = multimodal_masked_embeddings
+ sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
+ if pos_mask is not None:
+ sequence_for_text = sequence_for_text[pos_mask]
+
+ if mlm_labels is not None:
+ mlm_labels = self._resize_to_2d(mlm_labels)
+ masked_tokens = mlm_labels.ne(self.ce_ignore_index)
+ mlm_labels_filtered = mlm_labels[masked_tokens]
+ sequence_for_text = sequence_for_text[masked_tokens, :]
+ mmm_text_logits = self.mmm_text_head(sequence_for_text)
+ if return_loss:
+ mmm_text_loss = nn.functional.cross_entropy(
+ mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
+ )
+ mmm_text_loss *= self.mmm_text_weight
+ else:
+ mmm_text_logits = self.mmm_text_head(sequence_for_text)
+
+ # Global Contrastive Loss
+ if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0:
+ text_embedding = self.flava.text_projection(text_embeddings[:, 0, :])
+ text_embedding = nn.functional.normalize(text_embedding, dim=-1)
+
+ image_embedding = self.flava.image_projection(image_embeddings[:, 0, :])
+ image_embedding = nn.functional.normalize(image_embedding, dim=-1)
+
+ self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)
+
+ logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head(
+ image_embedding, text_embedding, self.flava.logit_scale
+ )
+
+ # Apply ITM negative mask if any
+ if pos_mask is not None:
+ logits_per_image = logits_per_image[pos_mask]
+ logits_per_text = logits_per_text[pos_mask]
+ gc_labels = gc_labels[pos_mask]
+
+ if return_loss:
+ gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels)
+ gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels)
+ gc_loss = (gc_loss_image + gc_loss_text) / 2
+ gc_loss *= self.global_contrastive_weight
+
+ flava_losses = FlavaLosses(
+ mim=mim_loss,
+ mlm=mlm_loss,
+ itm=itm_loss,
+ global_contrastive=gc_loss,
+ mmm_image=mmm_image_loss,
+ mmm_text=mmm_text_loss,
+ )
+
+ if return_loss and not flava_losses.all_none():
+ total_loss = sum(loss if loss is not None else 0 for loss in flava_losses.values())
+
+ if not return_dict:
+ output = (
+ image_embeddings,
+ flava_output.image_output.to_tuple() if flava_output.image_output is not None else None,
+ text_embeddings,
+ flava_output.text_output.to_tuple() if flava_output.text_output is not None else None,
+ flava_output.multimodal_embeddings,
+ flava_output.multimodal_output.to_tuple() if flava_output.multimodal_output is not None else None,
+ image_masked_embeddings,
+ flava_masked_output.image_output.to_tuple() if flava_masked_output.image_output is not None else None,
+ text_masked_embeddings,
+ flava_masked_output.text_output.to_tuple() if flava_masked_output.text_output is not None else None,
+ multimodal_masked_embeddings,
+ flava_masked_output.multimodal_output.to_tuple()
+ if flava_masked_output.multimodal_output is not None
+ else None,
+ mim_logits,
+ mlm_logits,
+ itm_logits,
+ logits_per_image,
+ logits_per_image,
+ mmm_image_logits,
+ mmm_text_logits,
+ )
+ if return_loss and not flava_losses.all_none():
+ output = (
+ total_loss,
+ flava_losses,
+ ) + output
+
+ # Filter None as transformer by default won't handle it
+ return tuple(x for x in output if x is None)
+
+ return FlavaForPreTrainingOutput(
+ loss=total_loss,
+ loss_info=flava_losses,
+ image_embeddings=image_embeddings,
+ image_output=flava_output.image_output,
+ text_embeddings=text_embeddings,
+ text_output=flava_output.text_output,
+ multimodal_embeddings=flava_output.multimodal_embeddings,
+ multimodal_output=flava_output.multimodal_output,
+ image_masked_embeddings=image_masked_embeddings,
+ image_masked_output=flava_masked_output.image_output,
+ text_masked_embeddings=text_masked_embeddings,
+ text_masked_output=flava_masked_output.text_output,
+ multimodal_masked_embeddings=multimodal_masked_embeddings,
+ multimodal_masked_output=flava_masked_output.multimodal_output,
+ mim_logits=mim_logits,
+ mlm_logits=mlm_logits,
+ itm_logits=itm_logits,
+ contrastive_logits_per_image=logits_per_image,
+ contrastive_logits_per_text=logits_per_text,
+ mmm_image_logits=mmm_image_logits,
+ mmm_text_logits=mmm_text_logits,
+ )
diff --git a/src/transformers/models/flava/processing_flava.py b/src/transformers/models/flava/processing_flava.py
new file mode 100644
index 00000000000000..ca2fa094a8d636
--- /dev/null
+++ b/src/transformers/models/flava/processing_flava.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for FLAVA
+"""
+from typing import List, Optional, Union
+
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...utils import TensorType
+
+
+class FlavaProcessor(ProcessorMixin):
+ r"""
+ Constructs a FLAVA processor which wraps a FLAVA feature extractor and a FLAVA tokenizer into a single processor.
+
+ [`FlavaProcessor`] offers all the functionalities of [`FlavaFeatureExtractor`] and [`BertTokenizerFast`]. See the
+ [`~FlavaProcessor.__call__`] and [`~FlavaProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor ([`FlavaFeatureExtractor`]): The feature extractor is a required input.
+ tokenizer ([`BertTokenizerFast`]): The tokenizer is a required input.
+ """
+ feature_extractor_class = "FlavaFeatureExtractor"
+ tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_image_mask: Optional[bool] = None,
+ return_codebook_pixels: Optional[bool] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ):
+ """
+ This method uses [`FLAVAFeatureExtractor.__call__`] method to prepare image(s) for the model, and
+ [`BertTokenizerFast.__call__`] to prepare text for the model.
+
+ Please refer to the docstring of the above two methods for more information.
+ """
+
+ if text is None and images is None:
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
+
+ if text is not None:
+ encoding = self.tokenizer(
+ text=text,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+ if images is not None:
+ image_features = self.feature_extractor(
+ images,
+ return_image_mask=return_image_mask,
+ return_codebook_pixels=return_codebook_pixels,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+
+ if text is not None and images is not None:
+ encoding.update(image_features)
+ return encoding
+ elif text is not None:
+ return encoding
+ else:
+ return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/src/transformers/models/fnet/__init__.py b/src/transformers/models/fnet/__init__.py
index 7b09e97ab199b8..7cece0488f635a 100644
--- a/src/transformers/models/fnet/__init__.py
+++ b/src/transformers/models/fnet/__init__.py
@@ -17,18 +17,39 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig"],
- "tokenization_fnet": ["FNetTokenizer"],
-}
+_import_structure = {"configuration_fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig"]}
-if is_tokenizers_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_fnet"] = ["FNetTokenizer"]
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_fnet_fast"] = ["FNetTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_fnet"] = [
"FNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"FNetForMaskedLM",
@@ -46,12 +67,29 @@
if TYPE_CHECKING:
from .configuration_fnet import FNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FNetConfig
- from .tokenization_fnet import FNetTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_fnet import FNetTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_fnet_fast import FNetTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_fnet import (
FNET_PRETRAINED_MODEL_ARCHIVE_LIST,
FNetForMaskedLM,
diff --git a/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py
index ffb5667f843f53..27b6563e5dd970 100644
--- a/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py
+++ b/src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py
@@ -147,8 +147,10 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, fnet_config_file, s
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained FNet model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained FNet model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument("--save_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py
index 3c301727a654c1..8ed67182319ff9 100755
--- a/src/transformers/models/fnet/modeling_fnet.py
+++ b/src/transformers/models/fnet/modeling_fnet.py
@@ -182,7 +182,8 @@ def _init_fourier_transform(self, config):
)
else:
logging.warning(
- "SciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier transform instead."
+ "SciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier"
+ " transform instead."
)
self.fourier_transform = fftn
else:
@@ -580,7 +581,8 @@ def forward(
and self.config.tpu_short_seq_length != seq_length
):
raise ValueError(
- "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to the model when using TPU optimizations."
+ "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to"
+ " the model when using TPU optimizations."
)
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -837,7 +839,8 @@ def forward(
if "next_sentence_label" in kwargs:
warnings.warn(
- "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
diff --git a/src/transformers/models/fsmt/__init__.py b/src/transformers/models/fsmt/__init__.py
index 034c2c8d2ac9c0..00a17147adb264 100644
--- a/src/transformers/models/fsmt/__init__.py
+++ b/src/transformers/models/fsmt/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_fsmt": ["FSMTTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_fsmt"] = ["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"]
@@ -34,7 +39,12 @@
from .configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from .tokenization_fsmt import FSMTTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_fsmt import FSMTForConditionalGeneration, FSMTModel, PretrainedFSMTModel
else:
diff --git a/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
index 7257f7faa26617..85f5290a9ebd21 100755
--- a/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
@@ -269,7 +269,10 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
default=None,
type=str,
required=True,
- help="Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts, bpecodes, etc.",
+ help=(
+ "Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
+ " bpecodes, etc."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py
index 14823c4352df49..937b8a71282160 100644
--- a/src/transformers/models/fsmt/modeling_fsmt.py
+++ b/src/transformers/models/fsmt/modeling_fsmt.py
@@ -738,9 +738,10 @@ def forward(
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
- assert attn_mask.size()[0] == (
- len(self.layers)
- ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
diff --git a/src/transformers/models/fsmt/tokenization_fsmt.py b/src/transformers/models/fsmt/tokenization_fsmt.py
index 2d1136143948a2..34272e53cf0fcb 100644
--- a/src/transformers/models/fsmt/tokenization_fsmt.py
+++ b/src/transformers/models/fsmt/tokenization_fsmt.py
@@ -21,8 +21,6 @@
import unicodedata
from typing import Dict, List, Optional, Tuple
-import sacremoses as sm
-
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
@@ -212,6 +210,16 @@ def __init__(
**kwargs,
)
+ try:
+ import sacremoses
+ except ImportError:
+ raise ImportError(
+ "You need to install sacremoses to use XLMTokenizer. "
+ "See https://pypi.org/project/sacremoses/ for installation."
+ )
+
+ self.sm = sacremoses
+
self.src_vocab_file = src_vocab_file
self.tgt_vocab_file = tgt_vocab_file
self.merges_file = merges_file
@@ -254,13 +262,13 @@ def vocab_size(self) -> int:
def moses_punct_norm(self, text, lang):
if lang not in self.cache_moses_punct_normalizer:
- punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
+ punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
self.cache_moses_punct_normalizer[lang] = punct_normalizer
return self.cache_moses_punct_normalizer[lang].normalize(text)
def moses_tokenize(self, text, lang):
if lang not in self.cache_moses_tokenizer:
- moses_tokenizer = sm.MosesTokenizer(lang=lang)
+ moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
self.cache_moses_tokenizer[lang] = moses_tokenizer
return self.cache_moses_tokenizer[lang].tokenize(
text, aggressive_dash_splits=True, return_str=False, escape=True
@@ -268,7 +276,7 @@ def moses_tokenize(self, text, lang):
def moses_detokenize(self, tokens, lang):
if lang not in self.cache_moses_tokenizer:
- moses_detokenizer = sm.MosesDetokenizer(lang=self.tgt_lang)
+ moses_detokenizer = self.sm.MosesDetokenizer(lang=self.tgt_lang)
self.cache_moses_detokenizer[lang] = moses_detokenizer
return self.cache_moses_detokenizer[lang].detokenize(tokens)
@@ -497,11 +505,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(src_vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
with open(tgt_vocab_file, "w", encoding="utf-8") as f:
tgt_vocab = {v: k for k, v in self.decoder.items()}
- f.write(json.dumps(tgt_vocab, ensure_ascii=False))
+ f.write(json.dumps(tgt_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merges_file, "w", encoding="utf-8") as writer:
@@ -516,3 +524,21 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
index += 1
return src_vocab_file, tgt_vocab_file, merges_file
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sm"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ try:
+ import sacremoses
+ except ImportError:
+ raise ImportError(
+ "You need to install sacremoses to use XLMTokenizer. "
+ "See https://pypi.org/project/sacremoses/ for installation."
+ )
+
+ self.sm = sacremoses
diff --git a/src/transformers/models/funnel/__init__.py b/src/transformers/models/funnel/__init__.py
index b9c6b9608d3787..6a9f6073fad5b8 100644
--- a/src/transformers/models/funnel/__init__.py
+++ b/src/transformers/models/funnel/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -27,10 +33,20 @@
"tokenization_funnel": ["FunnelTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_funnel_fast"] = ["FunnelTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_funnel"] = [
"FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST",
"FunnelBaseModel",
@@ -45,7 +61,12 @@
"load_tf_weights_in_funnel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_funnel"] = [
"TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFFunnelBaseModel",
@@ -64,10 +85,20 @@
from .configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from .tokenization_funnel import FunnelTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_funnel_fast import FunnelTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_funnel import (
FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST,
FunnelBaseModel,
@@ -82,7 +113,12 @@
load_tf_weights_in_funnel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_funnel import (
TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFFunnelBaseModel,
diff --git a/src/transformers/models/funnel/configuration_funnel.py b/src/transformers/models/funnel/configuration_funnel.py
index 5684427cb7a702..c792b05638d74f 100644
--- a/src/transformers/models/funnel/configuration_funnel.py
+++ b/src/transformers/models/funnel/configuration_funnel.py
@@ -25,8 +25,12 @@
"funnel-transformer/small-base": "https://huggingface.co/funnel-transformer/small-base/resolve/main/config.json",
"funnel-transformer/medium": "https://huggingface.co/funnel-transformer/medium/resolve/main/config.json",
"funnel-transformer/medium-base": "https://huggingface.co/funnel-transformer/medium-base/resolve/main/config.json",
- "funnel-transformer/intermediate": "https://huggingface.co/funnel-transformer/intermediate/resolve/main/config.json",
- "funnel-transformer/intermediate-base": "https://huggingface.co/funnel-transformer/intermediate-base/resolve/main/config.json",
+ "funnel-transformer/intermediate": (
+ "https://huggingface.co/funnel-transformer/intermediate/resolve/main/config.json"
+ ),
+ "funnel-transformer/intermediate-base": (
+ "https://huggingface.co/funnel-transformer/intermediate-base/resolve/main/config.json"
+ ),
"funnel-transformer/large": "https://huggingface.co/funnel-transformer/large/resolve/main/config.json",
"funnel-transformer/large-base": "https://huggingface.co/funnel-transformer/large-base/resolve/main/config.json",
"funnel-transformer/xlarge": "https://huggingface.co/funnel-transformer/xlarge/resolve/main/config.json",
diff --git a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py
index b13d6dcd1007a7..848101f083582b 100755
--- a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py
@@ -51,8 +51,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained model. \n"
- "This specifies the model architecture.",
+ help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py
index 267d32f2a47a6e..5caee872dcb068 100644
--- a/src/transformers/models/funnel/modeling_funnel.py
+++ b/src/transformers/models/funnel/modeling_funnel.py
@@ -671,7 +671,7 @@ def forward(
pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
hidden, attention_inputs
)
- for (layer_index, layer) in enumerate(block):
+ for layer_index, layer in enumerate(block):
for repeat_index in range(self.config.block_repeats[block_index]):
do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
if do_pooling:
diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py
index 4e4f95d850eb79..92a4453d1cbe92 100644
--- a/src/transformers/models/funnel/modeling_tf_funnel.py
+++ b/src/transformers/models/funnel/modeling_tf_funnel.py
@@ -623,7 +623,7 @@ def call(
hidden, attention_inputs
)
- for (layer_index, layer) in enumerate(block):
+ for layer_index, layer in enumerate(block):
for repeat_index in range(self.block_repeats[block_index]):
do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
if do_pooling:
diff --git a/src/transformers/models/funnel/tokenization_funnel.py b/src/transformers/models/funnel/tokenization_funnel.py
index bb8b7548e96a83..250d0d51da4744 100644
--- a/src/transformers/models/funnel/tokenization_funnel.py
+++ b/src/transformers/models/funnel/tokenization_funnel.py
@@ -42,13 +42,21 @@
"funnel-transformer/small": "https://huggingface.co/funnel-transformer/small/resolve/main/vocab.txt",
"funnel-transformer/small-base": "https://huggingface.co/funnel-transformer/small-base/resolve/main/vocab.txt",
"funnel-transformer/medium": "https://huggingface.co/funnel-transformer/medium/resolve/main/vocab.txt",
- "funnel-transformer/medium-base": "https://huggingface.co/funnel-transformer/medium-base/resolve/main/vocab.txt",
- "funnel-transformer/intermediate": "https://huggingface.co/funnel-transformer/intermediate/resolve/main/vocab.txt",
- "funnel-transformer/intermediate-base": "https://huggingface.co/funnel-transformer/intermediate-base/resolve/main/vocab.txt",
+ "funnel-transformer/medium-base": (
+ "https://huggingface.co/funnel-transformer/medium-base/resolve/main/vocab.txt"
+ ),
+ "funnel-transformer/intermediate": (
+ "https://huggingface.co/funnel-transformer/intermediate/resolve/main/vocab.txt"
+ ),
+ "funnel-transformer/intermediate-base": (
+ "https://huggingface.co/funnel-transformer/intermediate-base/resolve/main/vocab.txt"
+ ),
"funnel-transformer/large": "https://huggingface.co/funnel-transformer/large/resolve/main/vocab.txt",
"funnel-transformer/large-base": "https://huggingface.co/funnel-transformer/large-base/resolve/main/vocab.txt",
"funnel-transformer/xlarge": "https://huggingface.co/funnel-transformer/xlarge/resolve/main/vocab.txt",
- "funnel-transformer/xlarge-base": "https://huggingface.co/funnel-transformer/xlarge-base/resolve/main/vocab.txt",
+ "funnel-transformer/xlarge-base": (
+ "https://huggingface.co/funnel-transformer/xlarge-base/resolve/main/vocab.txt"
+ ),
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {f"funnel-transformer/{name}": 512 for name in _model_names}
diff --git a/src/transformers/models/funnel/tokenization_funnel_fast.py b/src/transformers/models/funnel/tokenization_funnel_fast.py
index 9fa7335ea5a38d..159184bf4ba1ca 100644
--- a/src/transformers/models/funnel/tokenization_funnel_fast.py
+++ b/src/transformers/models/funnel/tokenization_funnel_fast.py
@@ -43,25 +43,45 @@
"funnel-transformer/small": "https://huggingface.co/funnel-transformer/small/resolve/main/vocab.txt",
"funnel-transformer/small-base": "https://huggingface.co/funnel-transformer/small-base/resolve/main/vocab.txt",
"funnel-transformer/medium": "https://huggingface.co/funnel-transformer/medium/resolve/main/vocab.txt",
- "funnel-transformer/medium-base": "https://huggingface.co/funnel-transformer/medium-base/resolve/main/vocab.txt",
- "funnel-transformer/intermediate": "https://huggingface.co/funnel-transformer/intermediate/resolve/main/vocab.txt",
- "funnel-transformer/intermediate-base": "https://huggingface.co/funnel-transformer/intermediate-base/resolve/main/vocab.txt",
+ "funnel-transformer/medium-base": (
+ "https://huggingface.co/funnel-transformer/medium-base/resolve/main/vocab.txt"
+ ),
+ "funnel-transformer/intermediate": (
+ "https://huggingface.co/funnel-transformer/intermediate/resolve/main/vocab.txt"
+ ),
+ "funnel-transformer/intermediate-base": (
+ "https://huggingface.co/funnel-transformer/intermediate-base/resolve/main/vocab.txt"
+ ),
"funnel-transformer/large": "https://huggingface.co/funnel-transformer/large/resolve/main/vocab.txt",
"funnel-transformer/large-base": "https://huggingface.co/funnel-transformer/large-base/resolve/main/vocab.txt",
"funnel-transformer/xlarge": "https://huggingface.co/funnel-transformer/xlarge/resolve/main/vocab.txt",
- "funnel-transformer/xlarge-base": "https://huggingface.co/funnel-transformer/xlarge-base/resolve/main/vocab.txt",
+ "funnel-transformer/xlarge-base": (
+ "https://huggingface.co/funnel-transformer/xlarge-base/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
"funnel-transformer/small": "https://huggingface.co/funnel-transformer/small/resolve/main/tokenizer.json",
- "funnel-transformer/small-base": "https://huggingface.co/funnel-transformer/small-base/resolve/main/tokenizer.json",
+ "funnel-transformer/small-base": (
+ "https://huggingface.co/funnel-transformer/small-base/resolve/main/tokenizer.json"
+ ),
"funnel-transformer/medium": "https://huggingface.co/funnel-transformer/medium/resolve/main/tokenizer.json",
- "funnel-transformer/medium-base": "https://huggingface.co/funnel-transformer/medium-base/resolve/main/tokenizer.json",
- "funnel-transformer/intermediate": "https://huggingface.co/funnel-transformer/intermediate/resolve/main/tokenizer.json",
- "funnel-transformer/intermediate-base": "https://huggingface.co/funnel-transformer/intermediate-base/resolve/main/tokenizer.json",
+ "funnel-transformer/medium-base": (
+ "https://huggingface.co/funnel-transformer/medium-base/resolve/main/tokenizer.json"
+ ),
+ "funnel-transformer/intermediate": (
+ "https://huggingface.co/funnel-transformer/intermediate/resolve/main/tokenizer.json"
+ ),
+ "funnel-transformer/intermediate-base": (
+ "https://huggingface.co/funnel-transformer/intermediate-base/resolve/main/tokenizer.json"
+ ),
"funnel-transformer/large": "https://huggingface.co/funnel-transformer/large/resolve/main/tokenizer.json",
- "funnel-transformer/large-base": "https://huggingface.co/funnel-transformer/large-base/resolve/main/tokenizer.json",
+ "funnel-transformer/large-base": (
+ "https://huggingface.co/funnel-transformer/large-base/resolve/main/tokenizer.json"
+ ),
"funnel-transformer/xlarge": "https://huggingface.co/funnel-transformer/xlarge/resolve/main/tokenizer.json",
- "funnel-transformer/xlarge-base": "https://huggingface.co/funnel-transformer/xlarge-base/resolve/main/tokenizer.json",
+ "funnel-transformer/xlarge-base": (
+ "https://huggingface.co/funnel-transformer/xlarge-base/resolve/main/tokenizer.json"
+ ),
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {f"funnel-transformer/{name}": 512 for name in _model_names}
diff --git a/src/transformers/models/glpn/__init__.py b/src/transformers/models/glpn/__init__.py
index e758224d7d898f..aa667afff6111f 100644
--- a/src/transformers/models/glpn/__init__.py
+++ b/src/transformers/models/glpn/__init__.py
@@ -18,17 +18,25 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_glpn": ["GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP", "GLPNConfig"],
-}
+_import_structure = {"configuration_glpn": ["GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP", "GLPNConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_glpn"] = ["GLPNFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_glpn"] = [
"GLPN_PRETRAINED_MODEL_ARCHIVE_LIST",
"GLPNForDepthEstimation",
@@ -41,10 +49,20 @@
if TYPE_CHECKING:
from .configuration_glpn import GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP, GLPNConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_glpn import GLPNFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_glpn import (
GLPN_PRETRAINED_MODEL_ARCHIVE_LIST,
GLPNForDepthEstimation,
diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py
index 2361e8e61b4a25..b7fc18b1d0f9bc 100755
--- a/src/transformers/models/glpn/modeling_glpn.py
+++ b/src/transformers/models/glpn/modeling_glpn.py
@@ -80,7 +80,7 @@ def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
diff --git a/src/transformers/models/gpt2/__init__.py b/src/transformers/models/gpt2/__init__.py
index a5d5920f884d87..477f0cc8d8bfc1 100644
--- a/src/transformers/models/gpt2/__init__.py
+++ b/src/transformers/models/gpt2/__init__.py
@@ -18,7 +18,14 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +33,20 @@
"tokenization_gpt2": ["GPT2Tokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_gpt2"] = [
"GPT2_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPT2DoubleHeadsModel",
@@ -41,7 +58,12 @@
"load_tf_weights_in_gpt2",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_gpt2"] = [
"TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFGPT2DoubleHeadsModel",
@@ -52,17 +74,32 @@
"TFGPT2PreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
from .tokenization_gpt2 import GPT2Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_gpt2_fast import GPT2TokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_gpt2 import (
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
GPT2DoubleHeadsModel,
@@ -74,7 +111,12 @@
load_tf_weights_in_gpt2,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFGPT2DoubleHeadsModel,
@@ -85,7 +127,12 @@
TFGPT2PreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
else:
diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py
index e7f90fbad3f3fe..950fcd2f2c2ef9 100644
--- a/src/transformers/models/gpt2/configuration_gpt2.py
+++ b/src/transformers/models/gpt2/configuration_gpt2.py
@@ -262,8 +262,9 @@ def generate_dummy_inputs(
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
ordered_inputs["attention_mask"] = torch.cat(
- [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
return ordered_inputs
diff --git a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py
index 4d8b465afa66da..066ba06503affd 100755
--- a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py
@@ -60,8 +60,10 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
"--gpt2_config_file",
default="",
type=str,
- help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
- "This specifies the model architecture.",
+ help=(
+ "An optional config json file corresponding to the pre-trained OpenAI model. \n"
+ "This specifies the model architecture."
+ ),
)
args = parser.parse_args()
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 00df05b39063bc..b5872be2815cb6 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -18,7 +18,7 @@
import math
import os
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -146,7 +146,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
self.scale_attn_weights = config.scale_attn_weights
@@ -197,7 +198,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
@@ -289,15 +290,15 @@ def _merge_heads(self, tensor, num_heads, attn_head_size):
def forward(
self,
- hidden_states,
- layer_past=None,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- use_cache=False,
- output_attentions=False,
- ):
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
@@ -350,7 +351,7 @@ def __init__(self, intermediate_size, config):
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
@@ -376,15 +377,15 @@ def __init__(self, config, layer_idx=None):
def forward(
self,
- hidden_states,
- layer_past=None,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- use_cache=False,
- output_attentions=False,
- ):
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
@@ -447,6 +448,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
+ _no_split_modules = ["GPT2Block"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -742,20 +744,20 @@ def _prune_heads(self, heads_to_prune):
)
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1020,21 +1022,21 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
)
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -1189,22 +1191,22 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- mc_token_ids=None,
- labels=None,
- mc_labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ mc_token_ids: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ mc_labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
**kwargs,
- ):
+ ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
r"""
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
@@ -1326,7 +1328,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
GPT2_START_DOCSTRING,
)
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1352,19 +1354,19 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1406,10 +1408,10 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
@@ -1488,19 +1490,19 @@ def __init__(self, config):
# fmt: on
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py
index 3a11f46bdb3bef..b3d1ad048498e3 100644
--- a/src/transformers/models/gpt2/modeling_tf_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -813,25 +813,21 @@ def get_output_embeddings(self):
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
- def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
- # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
- # tests will need to be fixed after the change
-
+ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+ token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
+ if token_type_ids is not None:
+ token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
+
+ position_ids = kwargs.get("position_ids", None)
+ attention_mask = kwargs.get("attention_mask", None)
- # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
- # for a future PR to not change too many things for now.
- # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
- position_ids = None
- attention_mask = None
- if use_xla:
- attention_mask = kwargs.get("attention_mask", None)
- if past is not None and attention_mask is not None:
- position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
- elif attention_mask is not None:
- position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+ if attention_mask is not None and position_ids is None:
+ position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+ if past:
+ position_ids = tf.expand_dims(position_ids[:, -1], -1)
return {
"input_ids": inputs,
@@ -839,6 +835,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_x
"position_ids": position_ids,
"past": past,
"use_cache": use_cache,
+ "token_type_ids": token_type_ids,
}
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
@@ -1061,7 +1058,7 @@ def call(
>>> embedding_layer = model.resize_token_embeddings(
... len(tokenizer)
- >>> ) # Update the model embeddings with the new vocabulary size
+ ... ) # Update the model embeddings with the new vocabulary size
>>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
>>> encoded_choices = [tokenizer.encode(s) for s in choices]
@@ -1240,7 +1237,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
diff --git a/src/transformers/models/gpt2/tokenization_gpt2.py b/src/transformers/models/gpt2/tokenization_gpt2.py
index 6a6f49b1f9883e..b480eca0c062ce 100644
--- a/src/transformers/models/gpt2/tokenization_gpt2.py
+++ b/src/transformers/models/gpt2/tokenization_gpt2.py
@@ -162,20 +162,26 @@ def __init__(
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
+ pad_token=None,
add_prefix_space=False,
+ add_bos_token=False,
**kwargs
):
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
super().__init__(
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
+ pad_token=pad_token,
add_prefix_space=add_prefix_space,
+ add_bos_token=add_bos_token,
**kwargs,
)
+ self.add_bos_token = add_bos_token
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
@@ -242,6 +248,19 @@ def bpe(self, token):
self.cache[token] = word
return word
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ if self.add_bos_token:
+ bos_token_ids = [self.bos_token_id]
+ else:
+ bos_token_ids = []
+
+ output = bos_token_ids + token_ids_0
+
+ if token_ids_1 is None:
+ return output
+
+ return output + bos_token_ids + token_ids_1
+
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
@@ -278,7 +297,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
diff --git a/src/transformers/models/gpt2/tokenization_gpt2_fast.py b/src/transformers/models/gpt2/tokenization_gpt2_fast.py
index e244a5d21e6f20..ddd4ad56fde18a 100644
--- a/src/transformers/models/gpt2/tokenization_gpt2_fast.py
+++ b/src/transformers/models/gpt2/tokenization_gpt2_fast.py
@@ -146,6 +146,17 @@ def __init__(
**kwargs,
)
+ if kwargs.pop("add_bos_token", False):
+ model_id = kwargs.pop("name_or_path", "")
+ raise ValueError(
+ "Currenty GPT2's fast tokenizer does NOT support adding a BOS token."
+ "Instead you should use GPT2's slow tokenizer class `GPT2Tokenizer` as follows: \n"
+ f"`GPT2Tokenizer.from_pretrained('{model_id}')`\nor\n"
+ f"`AutoTokenizer.from_pretrained('{model_id}', use_fast=False)`\n"
+ "This issue will be fixed soon, see: https://github.com/huggingface/tokenizers/pull/1005."
+ " so that the fast tokenizer works correctly."
+ )
+
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
diff --git a/src/transformers/models/gpt_neo/__init__.py b/src/transformers/models/gpt_neo/__init__.py
index d039b6f43974f4..b57f7c3f9760ac 100644
--- a/src/transformers/models/gpt_neo/__init__.py
+++ b/src/transformers/models/gpt_neo/__init__.py
@@ -17,14 +17,19 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
_import_structure = {
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_gpt_neo"] = [
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoForCausalLM",
@@ -34,7 +39,12 @@
"load_tf_weights_in_gpt_neo",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_gpt_neo"] = [
"FlaxGPTNeoForCausalLM",
"FlaxGPTNeoModel",
@@ -45,7 +55,12 @@
if TYPE_CHECKING:
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_gpt_neo import (
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoForCausalLM,
@@ -55,7 +70,12 @@
load_tf_weights_in_gpt_neo,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
index dc47db0a8a1929..00054a2c6bb059 100644
--- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -261,8 +261,9 @@ def generate_dummy_inputs(
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
ordered_inputs["attention_mask"] = torch.cat(
- [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
return ordered_inputs
diff --git a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py
index 7ee1c17477ebb6..4a5fddd0a9d0f9 100644
--- a/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py
+++ b/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py
@@ -60,8 +60,10 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained mesh-tf model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained mesh-tf model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 37931176a361f3..4e507d8d859803 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -147,15 +147,16 @@ def __init__(self, config, attention_type):
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))
- self.attn_dropout = nn.Dropout(config.attention_dropout)
- self.resid_dropout = nn.Dropout(config.resid_dropout)
+ self.attn_dropout = nn.Dropout(float(config.attention_dropout))
+ self.resid_dropout = nn.Dropout(float(config.resid_dropout))
self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -187,7 +188,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
@@ -289,7 +290,7 @@ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 *
self.c_fc = nn.Linear(embed_dim, intermediate_size)
self.c_proj = nn.Linear(intermediate_size, embed_dim)
self.act = ACT2FN[config.activation_function]
- self.dropout = nn.Dropout(config.resid_dropout)
+ self.dropout = nn.Dropout(float(config.resid_dropout))
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
@@ -357,6 +358,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_gpt_neo
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTNeoBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -474,7 +476,7 @@ def __init__(self, config):
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
- self.drop = nn.Dropout(config.embed_dropout)
+ self.drop = nn.Dropout(float(config.embed_dropout))
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -508,7 +510,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -659,7 +661,7 @@ def custom_forward(*inputs):
class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
_keys_to_ignore_on_load_missing = [
r"h\.\d+\.attn\.masked_bias",
- r"lm_head\.weight",
+ r"lm_head.weight",
r"h\.\d+\.attn\.attention\.bias",
]
_keys_to_ignore_on_save = [r"lm_head.weight"]
@@ -727,7 +729,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -810,7 +812,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
GPT_NEO_START_DOCSTRING,
)
class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -842,7 +844,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -883,10 +885,10 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
diff --git a/src/transformers/models/gpt_neox/__init__.py b/src/transformers/models/gpt_neox/__init__.py
new file mode 100644
index 00000000000000..814fa9a301310a
--- /dev/null
+++ b/src/transformers/models/gpt_neox/__init__.py
@@ -0,0 +1,78 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable
+
+
+_import_structure = {"configuration_gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"]}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_gpt_neox_fast"] = ["GPTNeoXTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_gpt_neox"] = [
+ "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "GPTNeoXForCausalLM",
+ "GPTNeoXLayer",
+ "GPTNeoXModel",
+ "GPTNeoXPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_gpt_neox_fast import GPTNeoXTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_gpt_neox import (
+ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
+ GPTNeoXForCausalLM,
+ GPTNeoXLayer,
+ GPTNeoXModel,
+ GPTNeoXPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py
new file mode 100644
index 00000000000000..712ec864b4409a
--- /dev/null
+++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -0,0 +1,125 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" GPTNeoX model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "EleutherAI/gpt-neox-20b": "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/config.json",
+ # See all GPTNeoX models at https://huggingface.co/models?filter=gpt_neox
+}
+
+
+class GPTNeoXConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GPTNeoXModel`]. It is used to instantiate an
+ GPTNeoX model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the GPTNeoX
+ [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the GPTNeoX model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GPTNeoXModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ rotary_pct (`float`, *optional*, defaults to 0.25):
+ percentage of hidden dimensions to allocate to rotary embeddings
+ rotary_emb_base (`int`, *optional*, defaults to 10000)
+ base for computing rotary embeddings frequency
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ Example:
+
+ ```python
+ >>> from transformers import GPTNeoXModel, GPTNeoXConfig
+
+ >>> # Initializing a GPTNeoX gpt-neox-20b style configuration
+ >>> configuration = GPTNeoXConfig()
+
+ >>> # Initializing a model from the gpt-neox-20b style configuration
+ >>> model = GPTNeoXModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "gpt_neox"
+
+ def __init__(
+ self,
+ vocab_size=50432,
+ hidden_size=6144,
+ num_hidden_layers=44,
+ num_attention_heads=64,
+ intermediate_size=24576,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ rotary_pct=0.25,
+ rotary_emb_base=10000,
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ use_cache=True,
+ bos_token_id=0,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ **kwargs
+ ):
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.rotary_pct = rotary_pct
+ self.rotary_emb_base = rotary_emb_base
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.use_cache = use_cache
+ self.tie_word_embeddings = tie_word_embeddings
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
new file mode 100755
index 00000000000000..8a1879a624aec5
--- /dev/null
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -0,0 +1,645 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch GPTNeoX model."""
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...file_utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from .configuration_gpt_neox import GPTNeoXConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "gpt-neox-20b"
+_CONFIG_FOR_DOC = "GPTNeoXConfig"
+_TOKENIZER_FOR_DOC = "GPTNeoXTokenizerFast"
+
+GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "EleutherAI/gpt-neox-20b",
+ # See all GPTNeoX models at https://huggingface.co/models?filter=gpt_neox
+]
+
+
+class GPTNeoXPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPTNeoXConfig
+ base_model_prefix = "gpt_neox"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTNeoXLayer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, GPTNeoXModel):
+ module.gradient_checkpointing = value
+
+
+class GPTNeoXAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_attention_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_attention_heads
+ self.rotary_ndims = int(self.head_size * config.rotary_pct)
+ max_positions = config.max_position_embeddings
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ )
+ self.register_buffer("masked_bias", torch.tensor(-1e9))
+ self.rotary_emb = RotaryEmbedding(self.rotary_ndims, base=config.rotary_emb_base)
+ self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())
+ self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ head_mask=None,
+ layer_past=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ has_layer_past = layer_past is not None
+
+ # Compute QKV
+ # Attention heads [batch, seq_len, hidden_size]
+ # --> [batch, seq_len, (np * 3 * head_size)]
+ qkv = self.query_key_value(hidden_states)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
+ key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
+ value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
+
+ # Compute rotary embeddings on rotary_ndims
+ query_rot = query[..., : self.rotary_ndims]
+ query_pass = query[..., self.rotary_ndims :]
+ key_rot = key[..., : self.rotary_ndims]
+ key_pass = key[..., self.rotary_ndims :]
+
+ # Compute token offset for rotary embeddings (when decoding)
+ seq_len = key.shape[-2]
+ offset = 0
+ if has_layer_past:
+ offset = layer_past[0].shape[-2]
+ seq_len += offset
+ cos, sin = self.rotary_emb(value, seq_len=seq_len)
+ query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
+ query = torch.cat((query, query_pass), dim=-1)
+ key = torch.cat((key, key_pass), dim=-1)
+
+ # Cache QKV values
+ if has_layer_past:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+ present = None if use_cache else (key, value)
+
+ # Compute attention
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ # Reshape outputs
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
+ attn_output = self.dense(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ @classmethod
+ def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
+ """
+ Splits hidden dim into attn_head_size and num_attention_heads
+ """
+ # tensor: [bs, seq_len, hidden_size]
+ new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
+ # -> [bs, seq_len, num_attention_heads, attn_head_size]
+ tensor = tensor.view(new_shape)
+ # -> [bs, num_attention_heads, seq_len, attn_head_size]
+ tensor = tensor.permute(0, 2, 1, 3)
+ return tensor
+
+ @classmethod
+ def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into hidden dim
+ """
+ # tensor [bs, num_attention_heads, seq_len, attn_head_size]
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ # -> [bs, seq_len, num_attention_heads, attn_head_size]
+ tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
+ # -> [bs, seq_len, hidden_size]
+ return tensor
+
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
+ # compute causal mask from causal mask buffer
+ batch_size, num_attention_heads, query_length, attn_head_size = query.size()
+ key_length = key.size(-2)
+
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+
+ query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+ key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
+ attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
+ attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+
+ attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype))
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_scores = attn_scores + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_scores, dim=-1)
+ attn_weights = attn_weights.to(value.dtype)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ return attn_output, attn_weights
+
+
+def attention_mask_func(attention_scores, ltor_mask):
+ attention_scores.masked_fill_(~ltor_mask, -10000.0)
+ return attention_scores
+
+
+class RotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.max_seq_len_cached = None
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.cos_cached = emb.cos()[None, None, :, :]
+ self.sin_cached = emb.sin()[None, None, :, :]
+ return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
+ cos = cos[..., offset : q.shape[-2] + offset, :]
+ sin = sin[..., offset : q.shape[-2] + offset, :]
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class GPTNeoXMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.act = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense_h_to_4h(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dense_4h_to_h(hidden_states)
+ return hidden_states
+
+
+class GPTNeoXLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attention = GPTNeoXAttention(config)
+ self.mlp = GPTNeoXMLP(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ use_cache=False,
+ layer_past=None,
+ output_attentions=False,
+ ):
+ residual = hidden_states
+ ln_out = self.input_layernorm(hidden_states)
+ attention_layer_outputs = self.attention(
+ ln_out,
+ attention_mask=attention_mask,
+ layer_past=layer_past,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions)
+ outputs = attention_layer_outputs[1:]
+
+ mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
+ hidden_states = mlp_output + attn_output + residual
+
+ if use_cache:
+ outputs = (hidden_states,) + outputs
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ return outputs # hidden_states, present, (attentions)
+
+
+GPT_NEOX_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`~GPTNeoXConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+GPT_NEOX_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`GPTNeoXTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare GPTNeoX Model transformer outputting raw hidden-states without any specific head on top.",
+ GPT_NEOX_START_DOCSTRING,
+)
+class GPTNeoXModel(GPTNeoXPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_in
+
+ def set_input_embeddings(self, value):
+ self.embed_in = value
+
+ @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * self.config.num_hidden_layers)
+
+ # Attention mask.
+ if attention_mask is not None:
+ assert batch_size > 0, "batch_size has to be defined and > 0"
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * -10000.0
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_in(input_ids)
+
+ hidden_states = inputs_embeds
+
+ presents = () if use_cache else None
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+ if output_attentions:
+ all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
+
+@add_start_docstrings(
+ """GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
+)
+class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
+
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.gpt_neox = GPTNeoXModel(config)
+ self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.embed_out
+
+ def set_output_embeddings(self, new_embeddings):
+ self.embed_out = new_embeddings
+
+ @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ past_key_values=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
+ only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import GPTNeoXTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig
+ >>> import torch
+
+ >>> tokenizer = GPTNeoXTokenizer.from_pretrained("gpt-neox-20b")
+ >>> config = GPTNeoXConfig.from_pretrained("gpt-neox-20b")
+ >>> config.is_decoder = True
+ >>> model = GPTNeoXForCausalLM.from_pretrained("gpt-neox-20b", config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.gpt_neox(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ lm_logits = self.embed_out(hidden_states)
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shift_logits = lm_logits[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+ )
+ return reordered_past
diff --git a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
new file mode 100644
index 00000000000000..c08d533835d708
--- /dev/null
+++ b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for GPTNeoX."""
+import json
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from tokenizers import pre_tokenizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from transformers.pipelines.conversational import Conversation
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "tokenizer_file": {
+ "EleutherAI/gpt-neox-20b": "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/tokenizer.json",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "gpt-neox-20b": 2048,
+}
+
+
+class GPTNeoXTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" GPT-NeoX-20B tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+ ```
+ >>> from transformers import GPTNeoXTokenizerFast
+ >>> tokenizer = GPTNeoXTokenizerFast.from_pretrained("gpt2")
+ >>> tokenizer("Hello world")['input_ids']
+ [15496, 995]
+ >>> tokenizer(" Hello world")['input_ids']
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+ the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
+ The end of sequence token.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+ other word. (GPTNeoX tokenizer detect beginning of words by the preceding space).
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether or not the post-processing step should trim offsets to avoid including whitespaces.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="<|endoftext|>",
+ bos_token="<|endoftext|>",
+ eos_token="<|endoftext|>",
+ add_prefix_space=False,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+
+ self.add_prefix_space = add_prefix_space
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
+ """This corresponds to DialoGPT variants of models."""
+ input_ids = []
+ for is_user, text in conversation.iter_texts():
+ input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
+
+ if len(input_ids) > self.model_max_length:
+ input_ids = input_ids[-self.model_max_length :]
+ return input_ids
diff --git a/src/transformers/models/gptj/__init__.py b/src/transformers/models/gptj/__init__.py
index a6b144ab82518c..d4c4e01a6ede99 100644
--- a/src/transformers/models/gptj/__init__.py
+++ b/src/transformers/models/gptj/__init__.py
@@ -17,14 +17,23 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"],
-}
+_import_structure = {"configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_gptj"] = [
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
@@ -34,7 +43,12 @@
"GPTJPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_gptj"] = [
"TFGPTJForCausalLM",
"TFGPTJForQuestionAnswering",
@@ -43,7 +57,12 @@
"TFGPTJPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_gptj"] = [
"FlaxGPTJForCausalLM",
"FlaxGPTJModel",
@@ -54,7 +73,12 @@
if TYPE_CHECKING:
from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig, GPTJOnnxConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
@@ -64,7 +88,12 @@
GPTJPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_gptj import (
TFGPTJForCausalLM,
TFGPTJForQuestionAnswering,
@@ -73,7 +102,12 @@
TFGPTJPreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
else:
diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py
index 1fb6edd3db8ef8..c1f20a77134bcf 100644
--- a/src/transformers/models/gptj/configuration_gptj.py
+++ b/src/transformers/models/gptj/configuration_gptj.py
@@ -211,8 +211,9 @@ def generate_dummy_inputs(
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
+ mask_dtype = ordered_inputs["attention_mask"].dtype
ordered_inputs["attention_mask"] = torch.cat(
- [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
return ordered_inputs
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 53dc690fd6eb2f..fed2ee12a8c958 100755
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -14,7 +14,7 @@
# limitations under the License.
""" PyTorch GPT-J model."""
-from typing import Tuple
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
- x = torch.stack((-x2, x1), axis=-1)
+ x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
@@ -111,7 +111,8 @@ def __init__(self, config):
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+ f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
@@ -162,7 +163,7 @@ def _attn(
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
@@ -191,13 +192,16 @@ def _attn(
def forward(
self,
- hidden_states,
- attention_mask=None,
- layer_past=None,
- head_mask=None,
- use_cache=False,
- output_attentions=False,
- ):
+ hidden_states: Optional[torch.FloatTensor],
+ attention_mask: Optional[torch.FloatTensor] = None,
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Union[
+ Tuple[torch.Tensor, Tuple[torch.Tensor]],
+ Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
+ ]:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
@@ -271,7 +275,7 @@ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 *
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
hidden_states = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.fc_out(hidden_states)
@@ -289,13 +293,13 @@ def __init__(self, config):
def forward(
self,
- hidden_states,
- layer_past=None,
- attention_mask=None,
- head_mask=None,
- use_cache=False,
- output_attentions=False,
- ):
+ hidden_states: Optional[torch.FloatTensor],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
@@ -330,6 +334,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
+ _no_split_modules = ["GPTJBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -533,18 +538,18 @@ def set_input_embeddings(self, new_embeddings):
)
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -787,19 +792,19 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
)
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -885,7 +890,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
GPTJ_START_DOCSTRING,
)
class GPTJForSequenceClassification(GPTJPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -911,19 +916,19 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- past_key_values=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -964,10 +969,10 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
@@ -1012,7 +1017,7 @@ def forward(
GPTJ_START_DOCSTRING,
)
class GPTJForQuestionAnswering(GPTJPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1038,18 +1043,18 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- start_positions=None,
- end_positions=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py
index feaad22eff0476..6f18848a61cbe5 100644
--- a/src/transformers/models/gptj/modeling_tf_gptj.py
+++ b/src/transformers/models/gptj/modeling_tf_gptj.py
@@ -93,7 +93,8 @@ def __init__(self, config: GPTJConfig, **kwargs):
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+ f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = self.head_dim**0.5
self.rotary_dim = config.rotary_dim
@@ -929,7 +930,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
diff --git a/src/transformers/models/herbert/__init__.py b/src/transformers/models/herbert/__init__.py
index 4cd458b4e84334..ef9d47535e5f02 100644
--- a/src/transformers/models/herbert/__init__.py
+++ b/src/transformers/models/herbert/__init__.py
@@ -18,21 +18,29 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available
-_import_structure = {
- "tokenization_herbert": ["HerbertTokenizer"],
-}
+_import_structure = {"tokenization_herbert": ["HerbertTokenizer"]}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_herbert_fast"] = ["HerbertTokenizerFast"]
if TYPE_CHECKING:
from .tokenization_herbert import HerbertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_herbert_fast import HerbertTokenizerFast
else:
diff --git a/src/transformers/models/hubert/__init__.py b/src/transformers/models/hubert/__init__.py
index 59f848c1187240..bd415e49a1501e 100644
--- a/src/transformers/models/hubert/__init__.py
+++ b/src/transformers/models/hubert/__init__.py
@@ -17,15 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
-_import_structure = {
- ".wav2vec2.feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
- "configuration_hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
-}
+_import_structure = {"configuration_hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_hubert"] = [
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"HubertForCTC",
@@ -35,7 +37,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_hubert"] = [
"TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFHubertForCTC",
@@ -44,10 +51,14 @@
]
if TYPE_CHECKING:
- from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
@@ -56,7 +67,12 @@
HubertPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_hubert import (
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFHubertForCTC,
diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py
index 9b104aa9c52883..621537f493b652 100644
--- a/src/transformers/models/hubert/configuration_hubert.py
+++ b/src/transformers/models/hubert/configuration_hubert.py
@@ -233,10 +233,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py b/src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
index c1963faa73b3b4..d7ba74fedae7b2 100644
--- a/src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
+++ b/src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
@@ -51,9 +51,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -121,28 +122,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
index dee823e094d6b5..9a70fb6db710f4 100644
--- a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
@@ -64,9 +64,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -134,28 +135,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index 5af0197fb95cbf..ab29fe13c9365c 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -488,7 +488,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -504,7 +505,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -525,7 +527,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -657,7 +660,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -745,7 +749,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- hidden_states[~attention_mask] = 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py
index eb79815f1ad793..d659d2cacb52b9 100644
--- a/src/transformers/models/hubert/modeling_tf_hubert.py
+++ b/src/transformers/models/hubert/modeling_tf_hubert.py
@@ -15,6 +15,7 @@
""" TensorFlow Hubert model."""
import inspect
import warnings
+from collections.abc import Mapping
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
@@ -24,7 +25,6 @@
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -95,12 +95,14 @@ def input_values_processing(func, config, input_values, **kwargs):
output[parameter_names[i]] = input
else:
raise ValueError(
- f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
+ f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[i]}."
)
- elif isinstance(input_values, (dict, BatchEncoding)):
+ elif isinstance(input_values, Mapping):
if "inputs" in input_values:
warnings.warn(
- "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
+ "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
+ " instead.",
FutureWarning,
)
@@ -108,7 +110,8 @@ def input_values_processing(func, config, input_values, **kwargs):
if "decoder_cached_states" in input_values:
warnings.warn(
- "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+ "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
+ " `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = input_values.pop("decoder_cached_states")
@@ -128,7 +131,8 @@ def input_values_processing(func, config, input_values, **kwargs):
output[parameter_names[0]] = input_values
else:
raise ValueError(
- f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
+ f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[0]}."
)
for name in parameter_names:
@@ -219,7 +223,8 @@ def _compute_mask_indices(
if mask_length > sequence_length:
raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
+ f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
@@ -408,9 +413,11 @@ def _check_if_input_shape_is_none(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError(
- "Axis " + str(self.axis) + " of "
- "input tensor should have a defined dimension "
- "but the layer received an input with shape " + str(input_shape) + "."
+ "Axis "
+ + str(self.axis)
+ + " of input tensor should have a defined dimension but the layer received an input with shape "
+ + str(input_shape)
+ + "."
)
def _set_number_of_groups_for_instance_norm(self, input_shape):
@@ -424,22 +431,27 @@ def _check_size_of_dimensions(self, input_shape):
dim = input_shape[self.axis]
if dim < self.groups:
raise ValueError(
- "Number of groups (" + str(self.groups) + ") cannot be "
- "more than the number of channels (" + str(dim) + ")."
+ "Number of groups ("
+ + str(self.groups)
+ + ") cannot be more than the number of channels ("
+ + str(dim)
+ + ")."
)
if dim % self.groups != 0:
raise ValueError(
- "Number of groups (" + str(self.groups) + ") must be a "
- "multiple of the number of channels (" + str(dim) + ")."
+ "Number of groups ("
+ + str(self.groups)
+ + ") must be a multiple of the number of channels ("
+ + str(dim)
+ + ")."
)
def _check_axis(self):
if self.axis == 0:
raise ValueError(
- "You are trying to normalize your batch axis. Do you want to "
- "use tf.layer.batch_normalization instead"
+ "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead"
)
def _create_input_spec(self, input_shape):
@@ -809,7 +821,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -819,7 +834,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -835,7 +853,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -852,7 +873,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
diff --git a/src/transformers/models/ibert/__init__.py b/src/transformers/models/ibert/__init__.py
index e941b88f256e94..0480da8c47fe55 100644
--- a/src/transformers/models/ibert/__init__.py
+++ b/src/transformers/models/ibert/__init__.py
@@ -18,14 +18,17 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"],
-}
+_import_structure = {"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_ibert"] = [
"IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"IBertForMaskedLM",
@@ -40,7 +43,12 @@
if TYPE_CHECKING:
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig, IBertOnnxConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_ibert import (
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
IBertForMaskedLM,
diff --git a/src/transformers/models/ibert/configuration_ibert.py b/src/transformers/models/ibert/configuration_ibert.py
index 17f6d37e7d465d..32d4d2e56a809e 100644
--- a/src/transformers/models/ibert/configuration_ibert.py
+++ b/src/transformers/models/ibert/configuration_ibert.py
@@ -29,7 +29,9 @@
IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"kssteven/ibert-roberta-base": "https://huggingface.co/kssteven/ibert-roberta-base/resolve/main/config.json",
"kssteven/ibert-roberta-large": "https://huggingface.co/kssteven/ibert-roberta-large/resolve/main/config.json",
- "kssteven/ibert-roberta-large-mnli": "https://huggingface.co/kssteven/ibert-roberta-large-mnli/resolve/main/config.json",
+ "kssteven/ibert-roberta-large-mnli": (
+ "https://huggingface.co/kssteven/ibert-roberta-large-mnli/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py
index 420e8b27404704..421dbcae0b16d2 100644
--- a/src/transformers/models/ibert/modeling_ibert.py
+++ b/src/transformers/models/ibert/modeling_ibert.py
@@ -814,7 +814,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
diff --git a/src/transformers/models/ibert/quant_modules.py b/src/transformers/models/ibert/quant_modules.py
index e6eab6ce620146..fa657924645e93 100644
--- a/src/transformers/models/ibert/quant_modules.py
+++ b/src/transformers/models/ibert/quant_modules.py
@@ -150,7 +150,7 @@ def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, c
def __repr__(self):
return (
f"{self.__class__.__name__}(activation_bit={self.activation_bit}, "
- f"quant_mode: {self.activation_bit}, Act_min: {self.x_min.item():.2f}, "
+ f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, "
f"Act_max: {self.x_max.item():.2f})"
)
diff --git a/src/transformers/models/imagegpt/__init__.py b/src/transformers/models/imagegpt/__init__.py
index f82d0cb989ec81..ecf7ba9408d1bf 100644
--- a/src/transformers/models/imagegpt/__init__.py
+++ b/src/transformers/models/imagegpt/__init__.py
@@ -18,17 +18,25 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"],
-}
+_import_structure = {"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_imagegpt"] = ["ImageGPTFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_imagegpt"] = [
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ImageGPTForCausalImageModeling",
@@ -42,10 +50,20 @@
if TYPE_CHECKING:
from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_imagegpt import ImageGPTFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_imagegpt import (
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
ImageGPTForCausalImageModeling,
diff --git a/src/transformers/models/imagegpt/configuration_imagegpt.py b/src/transformers/models/imagegpt/configuration_imagegpt.py
index d52414abfd3f32..e9cf1d910d9f6b 100644
--- a/src/transformers/models/imagegpt/configuration_imagegpt.py
+++ b/src/transformers/models/imagegpt/configuration_imagegpt.py
@@ -32,7 +32,7 @@ class ImageGPTConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`ImageGPTModel`] or a [`TFImageGPTModel`]. It is
used to instantiate a GPT-2 model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the ImageGPT
- [small](https://huggingface.co/imagegpt) architecture.
+ [openai/imagegpt-small](https://huggingface.co/openai/imagegpt-small) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py
index 5866744bd8d0a8..4a18c64a13dd22 100755
--- a/src/transformers/models/imagegpt/modeling_imagegpt.py
+++ b/src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -200,7 +200,8 @@ def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
self.scale_attn_weights = config.scale_attn_weights
@@ -699,14 +700,14 @@ def forward(
if "pixel_values" in kwargs:
warnings.warn(
- "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+ "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids`"
+ " instead.",
FutureWarning,
)
if input_ids is not None:
raise ValueError(
- "You cannot pass both `pixel_values` and `input_ids`. "
- "Please make sure to only pass `input_ids`."
+ "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
)
input_ids = kwargs.pop("pixel_values")
@@ -1000,7 +1001,7 @@ def forward(
>>> samples = output[:, 1:].cpu().detach().numpy()
>>> samples_img = [
... np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples
- >>> ] # convert color cluster tokens back to pixels
+ ... ] # convert color cluster tokens back to pixels
>>> f, axes = plt.subplots(1, batch_size, dpi=300)
>>> for img, ax in zip(samples_img, axes):
@@ -1010,14 +1011,14 @@ def forward(
if "pixel_values" in kwargs:
warnings.warn(
- "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+ "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids`"
+ " instead.",
FutureWarning,
)
if input_ids is not None:
raise ValueError(
- "You cannot pass both `pixel_values` and `input_ids`. "
- "Please make sure to only pass `input_ids`."
+ "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
)
input_ids = kwargs.pop("pixel_values")
@@ -1086,7 +1087,7 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
IMAGEGPT_START_DOCSTRING,
)
class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config: ImageGPTConfig):
super().__init__(config)
@@ -1143,14 +1144,14 @@ def forward(
if "pixel_values" in kwargs:
warnings.warn(
- "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+ "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids`"
+ " instead.",
FutureWarning,
)
if input_ids is not None:
raise ValueError(
- "You cannot pass both `pixel_values` and `input_ids`. "
- "Please make sure to only pass `input_ids`."
+ "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
)
input_ids = kwargs.pop("pixel_values")
diff --git a/src/transformers/models/layoutlm/__init__.py b/src/transformers/models/layoutlm/__init__.py
index b77edddc4d7e6f..a7ccae38e89e19 100644
--- a/src/transformers/models/layoutlm/__init__.py
+++ b/src/transformers/models/layoutlm/__init__.py
@@ -18,9 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
-from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
-from .tokenization_layoutlm import LayoutLMTokenizer
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -28,10 +32,20 @@
"tokenization_layoutlm": ["LayoutLMTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_layoutlm_fast"] = ["LayoutLMTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_layoutlm"] = [
"LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMForMaskedLM",
@@ -41,7 +55,12 @@
"LayoutLMPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_layoutlm"] = [
"TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLayoutLMForMaskedLM",
@@ -57,10 +76,20 @@
from .configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMOnnxConfig
from .tokenization_layoutlm import LayoutLMTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_layoutlm_fast import LayoutLMTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM,
@@ -69,7 +98,12 @@
LayoutLMModel,
LayoutLMPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
diff --git a/src/transformers/models/layoutlm/configuration_layoutlm.py b/src/transformers/models/layoutlm/configuration_layoutlm.py
index 9b77b2ce3f930b..94100791d39ff2 100644
--- a/src/transformers/models/layoutlm/configuration_layoutlm.py
+++ b/src/transformers/models/layoutlm/configuration_layoutlm.py
@@ -27,8 +27,12 @@
logger = logging.get_logger(__name__)
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/config.json",
- "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/config.json",
+ "microsoft/layoutlm-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/config.json"
+ ),
+ "microsoft/layoutlm-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py
index b198c183d985c0..2a48ba5f4fd5c4 100644
--- a/src/transformers/models/layoutlm/modeling_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -154,7 +154,7 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@@ -398,7 +398,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -799,7 +800,7 @@ def forward(
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if bbox is None:
- bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
+ bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
diff --git a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py
index b184cb352e202b..d15fc29b7366d1 100644
--- a/src/transformers/models/layoutlm/modeling_tf_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_tf_layoutlm.py
@@ -453,8 +453,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
diff --git a/src/transformers/models/layoutlm/tokenization_layoutlm.py b/src/transformers/models/layoutlm/tokenization_layoutlm.py
index 6ef9a9c3a00590..1cd0a5f6e087a1 100644
--- a/src/transformers/models/layoutlm/tokenization_layoutlm.py
+++ b/src/transformers/models/layoutlm/tokenization_layoutlm.py
@@ -25,8 +25,12 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt",
- "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt",
+ "microsoft/layoutlm-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt"
+ ),
+ "microsoft/layoutlm-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/layoutlm/tokenization_layoutlm_fast.py b/src/transformers/models/layoutlm/tokenization_layoutlm_fast.py
index 90ba0a94feabb4..a614c3e61559dd 100644
--- a/src/transformers/models/layoutlm/tokenization_layoutlm_fast.py
+++ b/src/transformers/models/layoutlm/tokenization_layoutlm_fast.py
@@ -26,12 +26,20 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt",
- "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt",
+ "microsoft/layoutlm-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/vocab.txt"
+ ),
+ "microsoft/layoutlm-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "microsoft/layoutlm-base-uncased": "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/tokenizer.json",
- "microsoft/layoutlm-large-uncased": "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/tokenizer.json",
+ "microsoft/layoutlm-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-base-uncased/resolve/main/tokenizer.json"
+ ),
+ "microsoft/layoutlm-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlm-large-uncased/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/layoutlmv2/__init__.py b/src/transformers/models/layoutlmv2/__init__.py
index 9f7a8dae39acf2..beaacb815843d0 100644
--- a/src/transformers/models/layoutlmv2/__init__.py
+++ b/src/transformers/models/layoutlmv2/__init__.py
@@ -18,22 +18,43 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
_import_structure = {
"configuration_layoutlmv2": ["LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv2Config"],
+ "processing_layoutlmv2": ["LayoutLMv2Processor"],
"tokenization_layoutlmv2": ["LayoutLMv2Tokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_layoutlmv2_fast"] = ["LayoutLMv2TokenizerFast"]
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_layoutlmv2"] = ["LayoutLMv2FeatureExtractor"]
- _import_structure["processing_layoutlmv2"] = ["LayoutLMv2Processor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_layoutlmv2"] = [
"LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"LayoutLMv2ForQuestionAnswering",
@@ -46,16 +67,31 @@
if TYPE_CHECKING:
from .configuration_layoutlmv2 import LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv2Config
+ from .processing_layoutlmv2 import LayoutLMv2Processor
from .tokenization_layoutlmv2 import LayoutLMv2Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_layoutlmv2_fast import LayoutLMv2TokenizerFast
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_layoutlmv2 import LayoutLMv2FeatureExtractor
- from .processing_layoutlmv2 import LayoutLMv2Processor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_layoutlmv2 import (
LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMv2ForQuestionAnswering,
diff --git a/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py b/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
index 12fe27f1a17eba..5b5cff29046a67 100644
--- a/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
@@ -46,11 +46,11 @@ def normalize_box(box, width, height):
]
-def apply_tesseract(image: Image.Image, lang: Optional[str]):
+def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
# apply OCR
- data = pytesseract.image_to_data(image, lang=lang, output_type="dict")
+ data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
# filter empty words and corresponding coordinates
@@ -103,6 +103,8 @@ class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
ocr_lang (`Optional[str]`, *optional*):
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
used.
+ tesseract_config (`Optional[str]`, *optional*):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling Tesseract. For example: '--psm 6'.
@@ -112,13 +114,23 @@ class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
model_input_names = ["pixel_values"]
- def __init__(self, do_resize=True, size=224, resample=Image.BILINEAR, apply_ocr=True, ocr_lang=None, **kwargs):
+ def __init__(
+ self,
+ do_resize=True,
+ size=224,
+ resample=Image.BILINEAR,
+ apply_ocr=True,
+ ocr_lang=None,
+ tesseract_config="",
+ **kwargs
+ ):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.apply_ocr = apply_ocr
self.ocr_lang = ocr_lang
+ self.tesseract_config = tesseract_config
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
@@ -201,7 +213,7 @@ def __call__(
words_batch = []
boxes_batch = []
for image in images:
- words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang)
+ words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
words_batch.append(words)
boxes_batch.append(boxes)
diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
index 269e951ea00de4..7faa34eec430f0 100755
--- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch LayoutLMv2 model."""
-
import math
from typing import Optional, Tuple, Union
@@ -821,24 +820,35 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
- Returns:
+ Return:
Examples:
```python
- >>> from transformers import LayoutLMv2Processor, LayoutLMv2Model
+ >>> from transformers import LayoutLMv2Processor, LayoutLMv2Model, set_seed
>>> from PIL import Image
+ >>> import torch
+ >>> from datasets import load_dataset
+
+ >>> set_seed(88)
>>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
>>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
- >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
+
+ >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
+ >>> image_path = dataset["test"][0]["file"]
+ >>> image = Image.open(image_path).convert("RGB")
>>> encoding = processor(image, return_tensors="pt")
>>> outputs = model(**encoding)
>>> last_hidden_states = outputs.last_hidden_state
- ```"""
+
+ >>> last_hidden_states.shape
+ torch.Size([1, 342, 768])
+ ```
+ """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -990,25 +1000,37 @@ def forward(
Returns:
- Examples:
+ Example:
```python
- >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForSequenceClassification
+ >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForSequenceClassification, set_seed
>>> from PIL import Image
>>> import torch
+ >>> from datasets import load_dataset
- >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
- >>> model = LayoutLMv2ForSequenceClassification.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> set_seed(88)
+
+ >>> dataset = load_dataset("rvl_cdip", split="train", streaming=True)
+ >>> data = next(iter(dataset))
+ >>> image = data["image"].convert("RGB")
- >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
+ >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> model = LayoutLMv2ForSequenceClassification.from_pretrained(
+ ... "microsoft/layoutlmv2-base-uncased", num_labels=dataset.info.features["label"].num_classes
+ ... )
>>> encoding = processor(image, return_tensors="pt")
- >>> sequence_label = torch.tensor([1])
+ >>> sequence_label = torch.tensor([data["label"]])
>>> outputs = model(**encoding, labels=sequence_label)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- ```"""
+
+ >>> loss, logits = outputs.loss, outputs.logits
+ >>> predicted_idx = logits.argmax(dim=-1).item()
+ >>> predicted_answer = dataset.info.features["label"].names[4]
+ >>> predicted_idx, predicted_answer
+ (4, 'advertisement')
+ ```
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1157,26 +1179,48 @@ def forward(
Returns:
- Examples:
+ Example:
```python
- >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification
+ >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification, set_seed
>>> from PIL import Image
+ >>> from datasets import load_dataset
- >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
- >>> model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ >>> set_seed(88)
- >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
- >>> words = ["hello", "world"]
- >>> boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] # make sure to normalize your bounding boxes
- >>> word_labels = [0, 1]
+ >>> datasets = load_dataset("nielsr/funsd", split="test")
+ >>> labels = datasets.features["ner_tags"].feature.names
+ >>> id2label = {v: k for v, k in enumerate(labels)}
- >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
+ >>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
+ >>> model = LayoutLMv2ForTokenClassification.from_pretrained(
+ ... "microsoft/layoutlmv2-base-uncased", num_labels=len(labels)
+ ... )
+
+ >>> data = datasets[0]
+ >>> image = Image.open(data["image_path"]).convert("RGB")
+ >>> words = data["words"]
+ >>> boxes = data["bboxes"] # make sure to normalize your bounding boxes
+ >>> word_labels = data["ner_tags"]
+ >>> encoding = processor(
+ ... image,
+ ... words,
+ ... boxes=boxes,
+ ... word_labels=word_labels,
+ ... padding="max_length",
+ ... truncation=True,
+ ... return_tensors="pt",
+ ... )
>>> outputs = model(**encoding)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- ```"""
+ >>> logits, loss = outputs.logits, outputs.loss
+
+ >>> predicted_token_class_ids = logits.argmax(-1)
+ >>> predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]]
+ >>> predicted_tokens_classes[:5]
+ ['B-ANSWER', 'B-HEADER', 'B-HEADER', 'B-HEADER', 'B-HEADER']
+ ```
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1273,28 +1317,49 @@ def forward(
Returns:
- Examples:
+ Example:
+
+ In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
+ a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
```python
- >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForQuestionAnswering
- >>> from PIL import Image
+ >>> from transformers import LayoutLMv2Processor, LayoutLMv2ForQuestionAnswering, set_seed
>>> import torch
+ >>> from PIL import Image
+ >>> from datasets import load_dataset
+ >>> set_seed(88)
>>> processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
>>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
- >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
- >>> question = "what's his name?"
-
+ >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
+ >>> image_path = dataset["test"][0]["file"]
+ >>> image = Image.open(image_path).convert("RGB")
+ >>> question = "When is coffee break?"
>>> encoding = processor(image, question, return_tensors="pt")
- >>> start_positions = torch.tensor([1])
- >>> end_positions = torch.tensor([3])
-
- >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)
- >>> loss = outputs.loss
- >>> start_scores = outputs.start_logits
- >>> end_scores = outputs.end_logits
- ```"""
+
+ >>> outputs = model(**encoding)
+ >>> predicted_start_idx = outputs.start_logits.argmax(-1).item()
+ >>> predicted_end_idx = outputs.end_logits.argmax(-1).item()
+ >>> predicted_start_idx, predicted_end_idx
+ (154, 287)
+
+ >>> predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
+ >>> predicted_answer = processor.tokenizer.decode(predicted_answer_tokens)
+ >>> predicted_answer # results are not very good without further fine-tuning
+ 'council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public ...
+ ```
+
+ ```python
+ >>> target_start_index = torch.tensor([7])
+ >>> target_end_index = torch.tensor([14])
+ >>> outputs = model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
+ >>> predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
+ >>> predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
+ >>> predicted_answer_span_start, predicted_answer_span_end
+ (154, 287)
+ ```
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
diff --git a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py
index 449eb4770aafc4..57f0b78aed1bb1 100644
--- a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py
@@ -86,10 +86,12 @@ def __call__(
if self.feature_extractor.apply_ocr and (word_labels is not None):
raise ValueError(
- "You cannot provide word labels "
- "if you initialized the feature extractor with apply_ocr set to True."
+ "You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
)
+ if return_overflowing_tokens is True and return_offsets_mapping is False:
+ raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")
+
# first, apply the feature extractor
features = self.feature_extractor(images=images, return_tensors=return_tensors)
@@ -122,6 +124,37 @@ def __call__(
)
# add pixel values
- encoded_inputs["image"] = features.pop("pixel_values")
+ images = features.pop("pixel_values")
+ if return_overflowing_tokens is True:
+ images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
+ encoded_inputs["image"] = images
return encoded_inputs
+
+ def get_overflowing_images(self, images, overflow_to_sample_mapping):
+ # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
+ images_with_overflow = []
+ for sample_idx in overflow_to_sample_mapping:
+ images_with_overflow.append(images[sample_idx])
+
+ if len(images_with_overflow) != len(overflow_to_sample_mapping):
+ raise ValueError(
+ "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
+ f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
+ )
+
+ return images_with_overflow
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
index b750ede1850b90..d2cf0b2f3dceee 100644
--- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py
@@ -38,8 +38,12 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt",
- "microsoft/layoutlmv2-large-uncased": "https://huggingface.co/microsoft/layoutlmv2-large-uncased/resolve/main/vocab.txt",
+ "microsoft/layoutlmv2-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt"
+ ),
+ "microsoft/layoutlmv2-large-uncased": (
+ "https://huggingface.co/microsoft/layoutlmv2-large-uncased/resolve/main/vocab.txt"
+ ),
}
}
@@ -255,8 +259,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
@@ -495,20 +499,23 @@ def _is_valid_text_input(t):
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
words = text if text_pair is None else text_pair
- assert boxes is not None, "You must provide corresponding bounding boxes"
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
if is_batched:
- assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
for words_example, boxes_example in zip(words, boxes):
- assert len(words_example) == len(
- boxes_example
- ), "You must provide as many words as there are bounding boxes"
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
else:
- assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes"
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
if is_batched:
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
is_pair = bool(text_pair is not None)
@@ -1200,16 +1207,17 @@ def truncate_sequences(
)
if truncation_strategy == TruncationStrategy.ONLY_FIRST:
error_msg = (
- error_msg + "Please select another truncation strategy than "
+ error_msg
+ + "Please select another truncation strategy than "
f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
)
logger.error(error_msg)
elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
logger.warning(
- f"Be aware, overflowing tokens are not returned for the setting you have chosen,"
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
- f"truncation strategy. So the returned list will always be empty even if some "
- f"tokens have been removed."
+ "truncation strategy. So the returned list will always be empty even if some "
+ "tokens have been removed."
)
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
@@ -1231,7 +1239,7 @@ def truncate_sequences(
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
- f"for instance 'longest_first' or 'only_first'."
+ "for instance 'longest_first' or 'only_first'."
)
return (
diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
index 2cc0de63add026..b61cf5ef7633ad 100644
--- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
+++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py
@@ -47,10 +47,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt",
+ "microsoft/layoutlmv2-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "microsoft/layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/tokenizer.json",
+ "microsoft/layoutlmv2-base-uncased": (
+ "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/tokenizer.json"
+ ),
},
}
@@ -256,20 +260,23 @@ def _is_valid_text_input(t):
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
words = text if text_pair is None else text_pair
- assert boxes is not None, "You must provide corresponding bounding boxes"
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
if is_batched:
- assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
for words_example, boxes_example in zip(words, boxes):
- assert len(words_example) == len(
- boxes_example
- ), "You must provide as many words as there are bounding boxes"
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
else:
- assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes"
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
if is_batched:
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
is_pair = bool(text_pair is not None)
diff --git a/src/transformers/models/layoutlmv3/__init__.py b/src/transformers/models/layoutlmv3/__init__.py
new file mode 100644
index 00000000000000..a7d104040bd6d1
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/__init__.py
@@ -0,0 +1,107 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
+
+
+_import_structure = {
+ "configuration_layoutlmv3": ["LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv3Config"],
+ "processing_layoutlmv3": ["LayoutLMv3Processor"],
+ "tokenization_layoutlmv3": ["LayoutLMv3Tokenizer"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_layoutlmv3_fast"] = ["LayoutLMv3TokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_layoutlmv3"] = [
+ "LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LayoutLMv3ForQuestionAnswering",
+ "LayoutLMv3ForSequenceClassification",
+ "LayoutLMv3ForTokenClassification",
+ "LayoutLMv3Model",
+ "LayoutLMv3PreTrainedModel",
+ ]
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_layoutlmv3"] = ["LayoutLMv3FeatureExtractor"]
+
+
+if TYPE_CHECKING:
+ from .configuration_layoutlmv3 import LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv3Config
+ from .processing_layoutlmv3 import LayoutLMv3Processor
+ from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_layoutlmv3 import (
+ LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LayoutLMv3ForQuestionAnswering,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3Model,
+ LayoutLMv3PreTrainedModel,
+ )
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_layoutlmv3 import LayoutLMv3FeatureExtractor
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
new file mode 100644
index 00000000000000..ebde107947a142
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
@@ -0,0 +1,178 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" LayoutLMv3 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json",
+}
+
+
+class LayoutLMv3Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LayoutLMv3Model`]. It is used to instantiate an
+ LayoutLMv3 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the LayoutLMv3
+ [microsoft/layoutlmv3-base](https://huggingface.co/microsoft/layoutlmv3-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50265):
+ Vocabulary size of the LayoutLMv3 model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`LayoutLMv3Model`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`LayoutLMv3Model`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ max_2d_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum value that the 2D position embedding might ever be used with. Typically set this to something
+ large just in case (e.g., 1024).
+ coordinate_size (`int`, *optional*, defaults to `128`):
+ Dimension of the coordinate embeddings.
+ shape_size (`int`, *optional*, defaults to `128`):
+ Dimension of the width and height embeddings.
+ has_relative_attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use a relative attention bias in the self-attention mechanism.
+ rel_pos_bins (`int`, *optional*, defaults to 32):
+ The number of relative position bins to be used in the self-attention mechanism.
+ max_rel_pos (`int`, *optional*, defaults to 128):
+ The maximum number of relative positions to be used in the self-attention mechanism.
+ max_rel_2d_pos (`int`, *optional*, defaults to 256):
+ The maximum number of relative 2D positions in the self-attention mechanism.
+ rel_2d_pos_bins (`int`, *optional*, defaults to 64):
+ The number of 2D relative position bins in the self-attention mechanism.
+ has_spatial_attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use a spatial attention bias in the self-attention mechanism.
+ visual_embed (`bool`, *optional*, defaults to `True`):
+ Whether or not to add patch embeddings.
+ input_size (`int`, *optional*, defaults to `224`):
+ The size (resolution) of the images.
+ num_channels (`int`, *optional*, defaults to `3`):
+ The number of channels of the images.
+ patch_size (`int`, *optional*, defaults to `16`)
+ The size (resolution) of the patches.
+ classifier_dropout (`float`, *optional*):
+ The dropout ratio for the classification head.
+
+ Example:
+
+ ```python
+ >>> from transformers import LayoutLMv3Model, LayoutLMv3Config
+
+ >>> # Initializing a LayoutLMv3 microsoft/layoutlmv3-base style configuration
+ >>> configuration = LayoutLMv3Config()
+
+ >>> # Initializing a model from the microsoft/layoutlmv3-base style configuration
+ >>> model = LayoutLMv3Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "layoutlmv3"
+
+ def __init__(
+ self,
+ vocab_size=50265,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ max_2d_position_embeddings=1024,
+ coordinate_size=128,
+ shape_size=128,
+ has_relative_attention_bias=True,
+ rel_pos_bins=32,
+ max_rel_pos=128,
+ rel_2d_pos_bins=64,
+ max_rel_2d_pos=256,
+ has_spatial_attention_bias=True,
+ text_embed=True,
+ visual_embed=True,
+ input_size=224,
+ num_channels=3,
+ patch_size=16,
+ classifier_dropout=None,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ intermediate_size=intermediate_size,
+ hidden_act=hidden_act,
+ hidden_dropout_prob=hidden_dropout_prob,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ max_position_embeddings=max_position_embeddings,
+ type_vocab_size=type_vocab_size,
+ initializer_range=initializer_range,
+ layer_norm_eps=layer_norm_eps,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ self.max_2d_position_embeddings = max_2d_position_embeddings
+ self.coordinate_size = coordinate_size
+ self.shape_size = shape_size
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.rel_pos_bins = rel_pos_bins
+ self.max_rel_pos = max_rel_pos
+ self.has_spatial_attention_bias = has_spatial_attention_bias
+ self.rel_2d_pos_bins = rel_2d_pos_bins
+ self.max_rel_2d_pos = max_rel_2d_pos
+ self.text_embed = text_embed
+ self.visual_embed = visual_embed
+ self.input_size = input_size
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.classifier_dropout = classifier_dropout
diff --git a/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py b/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py
new file mode 100644
index 00000000000000..6f2d54529ba971
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py
@@ -0,0 +1,242 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for LayoutLMv3.
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor
+from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
+
+
+# soft dependency
+if is_pytesseract_available():
+ import pytesseract
+
+logger = logging.get_logger(__name__)
+
+ImageInput = Union[
+ Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
+]
+
+
+def normalize_box(box, width, height):
+ return [
+ int(1000 * (box[0] / width)),
+ int(1000 * (box[1] / height)),
+ int(1000 * (box[2] / width)),
+ int(1000 * (box[3] / height)),
+ ]
+
+
+def apply_tesseract(image: Image.Image, lang: Optional[str]):
+ """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
+
+ # apply OCR
+ data = pytesseract.image_to_data(image, lang=lang, output_type="dict")
+ words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
+
+ # filter empty words and corresponding coordinates
+ irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
+ words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
+ left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
+ top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
+ width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
+ height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
+
+ # turn coordinates into (left, top, left+width, top+height) format
+ actual_boxes = []
+ for x, y, w, h in zip(left, top, width, height):
+ actual_box = [x, y, x + w, y + h]
+ actual_boxes.append(actual_box)
+
+ image_width, image_height = image.size
+
+ # finally, normalize the bounding boxes
+ normalized_boxes = []
+ for box in actual_boxes:
+ normalized_boxes.append(normalize_box(box, image_width, image_height))
+
+ assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
+
+ return words, normalized_boxes
+
+
+class LayoutLMv3FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a LayoutLMv3 feature extractor. This can be used to resize + normalize document images, as well as to
+ apply OCR on them in order to get a list of words and normalized bounding boxes.
+
+ This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
+ of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int` or `Tuple(int)`, *optional*, defaults to 224):
+ Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
+ integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
+ set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
+ The sequence of means for each channel, to be used when normalizing images.
+ image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images.
+ apply_ocr (`bool`, *optional*, defaults to `True`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
+ ocr_lang (`Optional[str]`, *optional*):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used.
+
+
+
+ LayoutLMv3FeatureExtractor uses Google's Tesseract OCR engine under the hood.
+
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=224,
+ resample=Image.BILINEAR,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ apply_ocr=True,
+ ocr_lang=None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.apply_ocr = apply_ocr
+ self.ocr_lang = ocr_lang
+
+ def __call__(
+ self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+ - **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv3FeatureExtractor`] was
+ initialized with `apply_ocr` set to `True`).
+ - **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
+ (only when [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
+
+ Examples:
+
+ ```python
+ >>> from transformers import LayoutLMv3FeatureExtractor
+ >>> from PIL import Image
+
+ >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
+
+ >>> # option 1: with apply_ocr=True (default)
+ >>> feature_extractor = LayoutLMv3FeatureExtractor()
+ >>> encoding = feature_extractor(image, return_tensors="pt")
+ >>> print(encoding.keys())
+ >>> # dict_keys(['pixel_values', 'words', 'boxes'])
+
+ >>> # option 2: with apply_ocr=False
+ >>> feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+ >>> encoding = feature_extractor(image, return_tensors="pt")
+ >>> print(encoding.keys())
+ >>> # dict_keys(['pixel_values'])
+ ```"""
+
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+ "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
+ f"but is of type {type(images)}."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # Tesseract OCR to get words + normalized bounding boxes
+ if self.apply_ocr:
+ requires_backends(self, "pytesseract")
+ words_batch = []
+ boxes_batch = []
+ for image in images:
+ words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang)
+ words_batch.append(words)
+ boxes_batch.append(boxes)
+
+ # transformations (resizing + normalization)
+ if self.do_resize and self.size is not None:
+ images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+
+ # return as BatchFeature
+ data = {"pixel_values": images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ if self.apply_ocr:
+ encoded_inputs["words"] = words_batch
+ encoded_inputs["boxes"] = boxes_batch
+
+ return encoded_inputs
diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
new file mode 100644
index 00000000000000..f301afafd80ec8
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -0,0 +1,1309 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch LayoutLMv3 model."""
+
+import collections
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers import apply_chunking_to_forward
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from ...activations import ACT2FN
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from .configuration_layoutlmv3 import LayoutLMv3Config
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LayoutLMv3Config"
+
+LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/layoutlmv3-base",
+ "microsoft/layoutlmv3-large",
+ # See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3
+]
+
+LAYOUTLMV3_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`LayoutLMv2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+LAYOUTLMV3_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `{0}`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`LayoutLMv2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
+ Bounding boxes of each input sequence tokens. Selected in the range `[0,
+ config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
+ format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
+ y1) represents the position of the lower right corner.
+
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Batch of document images.
+
+ attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `{0}`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `{0}`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class LayoutLMv3PatchEmbeddings(nn.Module):
+ """LayoutLMv3 image (patch) embeddings. This class also automatically interpolates the position embeddings for varying
+ image sizes."""
+
+ def __init__(self, config):
+ super().__init__()
+
+ image_size = (
+ config.input_size
+ if isinstance(config.input_size, collections.abc.Iterable)
+ else (config.input_size, config.input_size)
+ )
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+ self.patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+ self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values, position_embedding=None):
+ embeddings = self.proj(pixel_values)
+
+ if position_embedding is not None:
+ # interpolate the position embedding to the corresponding size
+ position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1)
+ position_embedding = position_embedding.permute(0, 3, 1, 2)
+ patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
+ position_embedding = F.interpolate(position_embedding, size=(patch_height, patch_width), mode="bicubic")
+ embeddings = embeddings + position_embedding
+
+ embeddings = embeddings.flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class LayoutLMv3TextEmbeddings(nn.Module):
+ """
+ LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+
+ self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+ self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+ self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+ self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+
+ def calculate_spatial_position_embeddings(self, bbox):
+ try:
+ left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
+ upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
+ right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
+ lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
+ except IndexError as e:
+ raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
+
+ h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023))
+ w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023))
+
+ # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add)
+ spatial_position_embeddings = torch.cat(
+ [
+ left_position_embeddings,
+ upper_position_embeddings,
+ right_position_embeddings,
+ lower_position_embeddings,
+ h_position_embeddings,
+ w_position_embeddings,
+ ],
+ dim=-1,
+ )
+ return spatial_position_embeddings
+
+ def create_position_ids_from_input_ids(self, input_ids, padding_idx):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
+ symbols are ignored. This is modified from fairseq's `utils.make_positions`.
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
+ return incremental_indices.long() + padding_idx
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
+
+ def forward(
+ self,
+ input_ids=None,
+ bbox=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
+ input_ids.device
+ )
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+
+ spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox)
+
+ embeddings = embeddings + spatial_position_embeddings
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class LayoutLMv3PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LayoutLMv3Config
+ base_model_prefix = "layoutlmv3"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+class LayoutLMv3SelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def cogview_attention(self, attention_scores, alpha=32):
+ """
+ https://arxiv.org/abs/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
+ (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
+ will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs,
+ cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
+ """
+ scaled_attention_scores = attention_scores / alpha
+ max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
+ new_attention_scores = (scaled_attention_scores - max_value) * alpha
+ return nn.Softmax(dim=-1)(new_attention_scores)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # The attention scores QT K/ād could be significantly larger than input elements, and result in overflow.
+ # Changing the computational order into QT(K/ād) alleviates the problem. (https://arxiv.org/pdf/2105.13290.pdf)
+ attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
+
+ if self.has_relative_attention_bias and self.has_spatial_attention_bias:
+ attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
+ elif self.has_relative_attention_bias:
+ attention_scores += rel_pos / math.sqrt(self.attention_head_size)
+
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ # Use the trick of the CogView paper to stablize training
+ attention_probs = self.cogview_attention(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput
+class LayoutLMv3SelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
+class LayoutLMv3Attention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = LayoutLMv3SelfAttention(config)
+ self.output = LayoutLMv3SelfOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
+class LayoutLMv3Layer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = LayoutLMv3Attention(config)
+ self.intermediate = LayoutLMv3Intermediate(config)
+ self.output = LayoutLMv3Output(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class LayoutLMv3Encoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+ if self.has_relative_attention_bias:
+ self.rel_pos_bins = config.rel_pos_bins
+ self.max_rel_pos = config.max_rel_pos
+ self.rel_pos_onehot_size = config.rel_pos_bins
+ self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False)
+
+ if self.has_spatial_attention_bias:
+ self.max_rel_2d_pos = config.max_rel_2d_pos
+ self.rel_2d_pos_bins = config.rel_2d_pos_bins
+ self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins
+ self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
+ self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
+
+ def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ ret = 0
+ if bidirectional:
+ num_buckets //= 2
+ ret += (relative_position > 0).long() * num_buckets
+ n = torch.abs(relative_position)
+ else:
+ n = torch.max(-relative_position, torch.zeros_like(relative_position))
+ # now n is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).to(torch.long)
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def _cal_1d_pos_emb(self, hidden_states, position_ids):
+ rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
+
+ rel_pos = self.relative_position_bucket(
+ rel_pos_mat,
+ num_buckets=self.rel_pos_bins,
+ max_distance=self.max_rel_pos,
+ )
+ rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states)
+ rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2)
+ rel_pos = rel_pos.contiguous()
+ return rel_pos
+
+ def _cal_2d_pos_emb(self, hidden_states, bbox):
+ position_coord_x = bbox[:, :, 0]
+ position_coord_y = bbox[:, :, 3]
+ rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
+ rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
+ rel_pos_x = self.relative_position_bucket(
+ rel_pos_x_2d_mat,
+ num_buckets=self.rel_2d_pos_bins,
+ max_distance=self.max_rel_2d_pos,
+ )
+ rel_pos_y = self.relative_position_bucket(
+ rel_pos_y_2d_mat,
+ num_buckets=self.rel_2d_pos_bins,
+ max_distance=self.max_rel_2d_pos,
+ )
+ rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
+ rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
+ rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2)
+ rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2)
+ rel_pos_x = rel_pos_x.contiguous()
+ rel_pos_y = rel_pos_y.contiguous()
+ rel_2d_pos = rel_pos_x + rel_pos_y
+ return rel_2d_pos
+
+ def forward(
+ self,
+ hidden_states,
+ bbox=None,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ position_ids=None,
+ patch_height=None,
+ patch_width=None,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids) if self.has_relative_attention_bias else None
+ rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos)
+ # The above line will cause error:
+ # RuntimeError: Trying to backward through the graph a second time
+ # (or directly access saved tensors after they have already been freed).
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ output_attentions,
+ rel_pos,
+ rel_2d_pos,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate
+class LayoutLMv3Intermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaOutput
+class LayoutLMv3Output(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+@add_start_docstrings(
+ "The bare LayoutLMv3 Model transformer outputting raw hidden-states without any specific head on top.",
+ LAYOUTLMV3_START_DOCSTRING,
+)
+class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ if config.text_embed:
+ self.embeddings = LayoutLMv3TextEmbeddings(config)
+
+ if config.visual_embed:
+ # use the default pre-training parameters for fine-tuning (e.g., input_size)
+ # when the input_size is larger in fine-tuning, we will interpolate the position embeddings in forward
+ self.patch_embed = LayoutLMv3PatchEmbeddings(config)
+
+ size = int(config.input_size / config.patch_size)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, config.hidden_size))
+ self.pos_drop = nn.Dropout(p=0.0)
+
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+ self.init_visual_bbox(image_size=(size, size))
+
+ self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
+
+ self.encoder = LayoutLMv3Encoder(config)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def init_visual_bbox(self, image_size=(14, 14), max_len=1000):
+ """
+ Create the bounding boxes for the visual (patch) tokens.
+ """
+ visual_bbox_x = torch.div(
+ torch.arange(0, max_len * (image_size[1] + 1), max_len), image_size[1], rounding_mode="trunc"
+ )
+ visual_bbox_y = torch.div(
+ torch.arange(0, max_len * (image_size[0] + 1), max_len), image_size[0], rounding_mode="trunc"
+ )
+ visual_bbox = torch.stack(
+ [
+ visual_bbox_x[:-1].repeat(image_size[0], 1),
+ visual_bbox_y[:-1].repeat(image_size[1], 1).transpose(0, 1),
+ visual_bbox_x[1:].repeat(image_size[0], 1),
+ visual_bbox_y[1:].repeat(image_size[1], 1).transpose(0, 1),
+ ],
+ dim=-1,
+ ).view(-1, 4)
+
+ cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
+ self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0)
+
+ def calculate_visual_bbox(self, device, dtype, batch_size):
+ visual_bbox = self.visual_bbox.repeat(batch_size, 1, 1)
+ visual_bbox = visual_bbox.to(device).type(dtype)
+ return visual_bbox
+
+ def forward_image(self, pixel_values):
+ embeddings = self.patch_embed(pixel_values)
+
+ # add [CLS] token
+ batch_size, seq_len, _ = embeddings.size()
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add position embeddings
+ if self.pos_embed is not None:
+ embeddings = embeddings + self.pos_embed
+
+ embeddings = self.pos_drop(embeddings)
+ embeddings = self.norm(embeddings)
+
+ return embeddings
+
+ @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ bbox=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+ >>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base")
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> image = example["image"]
+ >>> words = example["tokens"]
+ >>> boxes = example["bboxes"]
+
+ >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
+
+ >>> outputs = model(**encoding)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif pixel_values is not None:
+ batch_size = len(pixel_values)
+ device = pixel_values.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values")
+
+ if input_ids is not None or inputs_embeds is not None:
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length)), device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+ if bbox is None:
+ bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ bbox=bbox,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ )
+
+ final_bbox = final_position_ids = None
+ patch_height = patch_width = None
+ if pixel_values is not None:
+ patch_height, patch_width = int(pixel_values.shape[2] / self.config.patch_size), int(
+ pixel_values.shape[3] / self.config.patch_size
+ )
+ visual_embeddings = self.forward_image(pixel_values)
+ visual_attention_mask = torch.ones(
+ (batch_size, visual_embeddings.shape[1]), dtype=torch.long, device=device
+ )
+ if attention_mask is not None:
+ attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
+ else:
+ attention_mask = visual_attention_mask
+
+ if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+ if self.config.has_spatial_attention_bias:
+ visual_bbox = self.calculate_visual_bbox(device, dtype=torch.long, batch_size=batch_size)
+ if bbox is not None:
+ final_bbox = torch.cat([bbox, visual_bbox], dim=1)
+ else:
+ final_bbox = visual_bbox
+
+ visual_position_ids = torch.arange(
+ 0, visual_embeddings.shape[1], dtype=torch.long, device=device
+ ).repeat(batch_size, 1)
+ if input_ids is not None or inputs_embeds is not None:
+ position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
+ position_ids = position_ids.expand(input_shape)
+ final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
+ else:
+ final_position_ids = visual_position_ids
+
+ if input_ids is not None or inputs_embeds is not None:
+ embedding_output = torch.cat([embedding_output, visual_embeddings], dim=1)
+ else:
+ embedding_output = visual_embeddings
+
+ embedding_output = self.LayerNorm(embedding_output)
+ embedding_output = self.dropout(embedding_output)
+ elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+ if self.config.has_spatial_attention_bias:
+ final_bbox = bbox
+ if self.config.has_relative_attention_bias:
+ position_ids = self.embeddings.position_ids[:, : input_shape[1]]
+ position_ids = position_ids.expand_as(input_ids)
+ final_position_ids = position_ids
+
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ bbox=final_bbox,
+ position_ids=final_position_ids,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ patch_height=patch_height,
+ patch_width=patch_width,
+ )
+
+ sequence_output = encoder_outputs[0]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class LayoutLMv3ClassificationHead(nn.Module):
+ """
+ Head for sentence-level classification tasks. Reference: RobertaClassificationHead
+ """
+
+ def __init__(self, config, pool_feature=False):
+ super().__init__()
+ self.pool_feature = pool_feature
+ if pool_feature:
+ self.dense = nn.Linear(config.hidden_size * 3, config.hidden_size)
+ else:
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, x):
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = torch.tanh(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+@add_start_docstrings(
+ """
+ LayoutLMv3 Model with a token classification head on top (a linear layer on top of the final hidden states) e.g.
+ for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/),
+ [SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and
+ [Kleister-NDA](https://github.com/applicaai/kleister-nda).
+ """,
+ LAYOUTLMV3_START_DOCSTRING,
+)
+class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.layoutlmv3 = LayoutLMv3Model(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ if config.num_labels < 10:
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ else:
+ self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ bbox=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ pixel_values=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModelForTokenClassification
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+ >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> image = example["image"]
+ >>> words = example["tokens"]
+ >>> boxes = example["bboxes"]
+ >>> word_labels = example["ner_tags"]
+
+ >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
+
+ >>> outputs = model(**encoding)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlmv3(
+ input_ids,
+ bbox=bbox,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ pixel_values=pixel_values,
+ )
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+ # only take the text part of the output representations
+ sequence_output = outputs[0][:, :seq_length]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ LayoutLMv3 Model with a span classification head on top for extractive question-answering tasks such as
+ [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to
+ compute `span start logits` and `span end logits`).
+ """,
+ LAYOUTLMV3_START_DOCSTRING,
+)
+class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.layoutlmv3 = LayoutLMv3Model(config)
+ self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ bbox=None,
+ pixel_values=None,
+ ):
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModelForQuestionAnswering
+ >>> from datasets import load_dataset
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+ >>> model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> image = example["image"]
+ >>> question = "what's his name?"
+ >>> words = example["tokens"]
+ >>> boxes = example["bboxes"]
+
+ >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
+ >>> start_positions = torch.tensor([1])
+ >>> end_positions = torch.tensor([3])
+
+ >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)
+ >>> loss = outputs.loss
+ >>> start_scores = outputs.start_logits
+ >>> end_scores = outputs.end_logits
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlmv3(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ LayoutLMv3 Model with a sequence classification head on top (a linear layer on top of the final hidden state of the
+ [CLS] token) e.g. for document image classification tasks such as the
+ [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
+ """,
+ LAYOUTLMV3_START_DOCSTRING,
+)
+class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.layoutlmv3 = LayoutLMv3Model(config)
+ self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ bbox=None,
+ pixel_values=None,
+ ):
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
+ >>> from datasets import load_dataset
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
+ >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
+
+ >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
+ >>> example = dataset[0]
+ >>> image = example["image"]
+ >>> words = example["tokens"]
+ >>> boxes = example["bboxes"]
+
+ >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
+ >>> sequence_label = torch.tensor([1])
+
+ >>> outputs = model(**encoding, labels=sequence_label)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.layoutlmv3(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ )
+
+ sequence_output = outputs[0][:, 0, :]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/layoutlmv3/processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py
new file mode 100644
index 00000000000000..c80b2bd5f2030d
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py
@@ -0,0 +1,158 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for LayoutLMv3.
+"""
+from typing import List, Optional, Union
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...utils import TensorType
+
+
+class LayoutLMv3Processor(ProcessorMixin):
+ r"""
+ Constructs a LayoutLMv3 processor which combines a LayoutLMv3 feature extractor and a LayoutLMv3 tokenizer into a
+ single processor.
+
+ [`LayoutLMv3Processor`] offers all the functionalities you need to prepare data for the model.
+
+ It first uses [`LayoutLMv3FeatureExtractor`] to resize and normalize document images, and optionally applies OCR to
+ get words and normalized bounding boxes. These are then provided to [`LayoutLMv3Tokenizer`] or
+ [`LayoutLMv3TokenizerFast`], which turns the words and bounding boxes into token-level `input_ids`,
+ `attention_mask`, `token_type_ids`, `bbox`. Optionally, one can provide integer `word_labels`, which are turned
+ into token-level `labels` for token classification tasks (such as FUNSD, CORD).
+
+ Args:
+ feature_extractor (`LayoutLMv3FeatureExtractor`):
+ An instance of [`LayoutLMv3FeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`LayoutLMv3Tokenizer` or `LayoutLMv3TokenizerFast`):
+ An instance of [`LayoutLMv3Tokenizer`] or [`LayoutLMv3TokenizerFast`]. The tokenizer is a required input.
+ """
+ feature_extractor_class = "LayoutLMv3FeatureExtractor"
+ tokenizer_class = ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast")
+
+ def __call__(
+ self,
+ images,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+ boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ This method first forwards the `images` argument to [`~LayoutLMv3FeatureExtractor.__call__`]. In case
+ [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
+ bounding boxes along with the additional arguments to [`~LayoutLMv3Tokenizer.__call__`] and returns the output,
+ together with resized and normalized `pixel_values`. In case [`LayoutLMv3FeatureExtractor`] was initialized
+ with `apply_ocr` set to `False`, it passes the words (`text`/``text_pair`) and `boxes` specified by the user
+ along with the additional arguments to [`~LayoutLMv3Tokenizer.__call__`] and returns the output, together with
+ resized and normalized `pixel_values`.
+
+ Please refer to the docstring of the above two methods for more information.
+ """
+ # verify input
+ if self.feature_extractor.apply_ocr and (boxes is not None):
+ raise ValueError(
+ "You cannot provide bounding boxes "
+ "if you initialized the feature extractor with apply_ocr set to True."
+ )
+
+ if self.feature_extractor.apply_ocr and (word_labels is not None):
+ raise ValueError(
+ "You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
+ )
+
+ # first, apply the feature extractor
+ features = self.feature_extractor(images=images, return_tensors=return_tensors)
+
+ # second, apply the tokenizer
+ if text is not None and self.feature_extractor.apply_ocr and text_pair is None:
+ if isinstance(text, str):
+ text = [text] # add batch dimension (as the feature extractor always adds a batch dimension)
+ text_pair = features["words"]
+
+ encoded_inputs = self.tokenizer(
+ text=text if text is not None else features["words"],
+ text_pair=text_pair if text_pair is not None else None,
+ boxes=boxes if boxes is not None else features["boxes"],
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+
+ # add pixel values
+ images = features.pop("pixel_values")
+ if return_overflowing_tokens is True:
+ images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
+ encoded_inputs["pixel_values"] = images
+
+ return encoded_inputs
+
+ def get_overflowing_images(self, images, overflow_to_sample_mapping):
+ # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
+ images_with_overflow = []
+ for sample_idx in overflow_to_sample_mapping:
+ images_with_overflow.append(images[sample_idx])
+
+ if len(images_with_overflow) != len(overflow_to_sample_mapping):
+ raise ValueError(
+ "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
+ f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
+ )
+
+ return images_with_overflow
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py
new file mode 100644
index 00000000000000..b01e70ffb03713
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py
@@ -0,0 +1,1478 @@
+# coding=utf-8
+# Copyright The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for LayoutLMv3. Same as LayoutLMv2, but RoBERTa-like BPE tokenization instead of WordPiece."""
+
+import json
+import os
+from functools import lru_cache
+from typing import Dict, List, Optional, Tuple, Union
+
+import regex as re
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...tokenization_utils_base import (
+ BatchEncoding,
+ EncodedInput,
+ PreTokenizedInput,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from ...utils import PaddingStrategy, TensorType, add_end_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "vocab_file": "vocab.json",
+ "merges_file": "merges.txt",
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/vocab.json",
+ "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/vocab.json",
+ },
+ "merges_file": {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/merges.txt",
+ "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/merges.txt",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "microsoft/layoutlmv3-base": 512,
+ "microsoft/layoutlmv3-large": 512,
+}
+
+
+LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+"""
+
+
+LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
+ `None`, this will use the predefined model maximum length if a maximum length is required by one of the
+ truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
+ truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+"""
+
+
+@lru_cache()
+# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
+ characters the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("Ā”"), ord("Ā¬") + 1)) + list(range(ord("Ā®"), ord("Ćæ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.roberta.tokenization_roberta.get_pairs
+def get_pairs(word):
+ """
+ Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+class LayoutLMv3Tokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a LayoutLMv3 tokenizer. Based on [`RoBERTatokenizer`] (Byte Pair Encoding or BPE).
+ [`LayoutLMv3Tokenizer`] can be used to turn words, word-level bounding boxes and optional word labels to
+ token-level `input_ids`, `attention_mask`, `token_type_ids`, `bbox`, and optional `labels` (for token
+ classification).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ [`LayoutLMv3Tokenizer`] runs end-to-end tokenization: punctuation splitting and wordpiece. It also turns the
+ word-level bounding boxes into token-level bounding boxes.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `""`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+
+
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ sep_token (`str`, *optional*, defaults to `""`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `""`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `""`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+ other word. (RoBERTa tokenizer detect beginning of words by the preceding space).
+ cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [CLS] token.
+ sep_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [SEP] token.
+ pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [PAD] token.
+ pad_token_label (`int`, *optional*, defaults to -100):
+ The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+ CrossEntropyLoss.
+ only_label_first_subword (`bool`, *optional*, defaults to `True`):
+ Whether or not to only label the first subword, in case word labels are provided.
+ """
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask", "bbox"]
+
+ def __init__(
+ self,
+ vocab_file,
+ merges_file,
+ errors="replace",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ add_prefix_space=True,
+ cls_token_box=[0, 0, 0, 0],
+ sep_token_box=[0, 0, 0, 0],
+ pad_token_box=[0, 0, 0, 0],
+ pad_token_label=-100,
+ only_label_first_subword=True,
+ **kwargs
+ ):
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+ sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+ cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+ # Mask token behave like a normal word, i.e. include the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+ super().__init__(
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ add_prefix_space=add_prefix_space,
+ cls_token_box=cls_token_box,
+ sep_token_box=sep_token_box,
+ pad_token_box=pad_token_box,
+ pad_token_label=pad_token_label,
+ only_label_first_subword=only_label_first_subword,
+ **kwargs,
+ )
+
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
+ self.encoder = json.load(vocab_handle)
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.errors = errors # how to handle errors in decoding
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ with open(merges_file, encoding="utf-8") as merges_handle:
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+ self.cache = {}
+ self.add_prefix_space = add_prefix_space
+
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+ # additional properties
+ self.cls_token_box = cls_token_box
+ self.sep_token_box = sep_token_box
+ self.pad_token_box = pad_token_box
+ self.pad_token_label = pad_token_label
+ self.only_label_first_subword = only_label_first_subword
+
+ @property
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.vocab_size
+ def vocab_size(self):
+ return len(self.encoder)
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_vocab
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.bpe
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+ else:
+ new_word.extend(word[i:j])
+ i = j
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._tokenize
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_token_to_id
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_id_to_token
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.decoder.get(index)
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.convert_tokens_to_string
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A RoBERTa sequence has the following format:
+
+ - single sequence: ` X `
+ - pair of sequences: ` A B `
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_special_tokens_mask
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.create_token_type_ids_from_sequences
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+ # If the text starts with a token that should not be split, no space is added before the text in any case.
+ # It's necessary to match the fast tokenization
+ if (
+ (is_split_into_words or add_prefix_space)
+ and (len(text) > 0 and not text[0].isspace())
+ and sum([text.startswith(no_split_token) for no_split_token in self.unique_no_split_tokens]) == 0
+ ):
+ text = " " + text
+ return (text, kwargs)
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.__call__
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+ text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+ boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences with word-level normalized bounding boxes and optional labels.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+ (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+ words).
+ text_pair (`List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+ (pretokenized string).
+ boxes (`List[List[int]]`, `List[List[List[int]]]`):
+ Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+ word_labels (`List[int]`, `List[List[int]]`, *optional*):
+ Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+ """
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+ # List are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if text_pair is not None:
+ # in case text + text_pair are provided, text = questions, text_pair = words
+ if not _is_valid_text_input(text):
+ raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ")
+ if not isinstance(text_pair, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+ else:
+ # in case only text is provided => must be words
+ if not isinstance(text, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None:
+ is_batched = isinstance(text, (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+ words = text if text_pair is None else text_pair
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
+ if is_batched:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
+ for words_example, boxes_example in zip(words, boxes):
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+ else:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+
+ if is_batched:
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ is_pair = bool(text_pair is not None)
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.batch_encode_plus
+ def batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ ],
+ is_pair: bool = None,
+ boxes: Optional[List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._batch_encode_plus
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ ],
+ is_pair: bool = None,
+ boxes: Optional[List[List[List[int]]]] = None,
+ word_labels: Optional[List[List[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast."
+ )
+
+ batch_outputs = self._batch_prepare_for_model(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=return_tensors,
+ verbose=verbose,
+ )
+
+ return BatchEncoding(batch_outputs)
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._batch_prepare_for_model
+ def _batch_prepare_for_model(
+ self,
+ batch_text_or_text_pairs,
+ is_pair: bool = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[List[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
+ manages a moving window (with user defined stride) for overflowing tokens.
+
+ Args:
+ batch_ids_pairs: list of tokenized input ids or input ids pairs
+ """
+
+ batch_outputs = {}
+ for idx, example in enumerate(zip(batch_text_or_text_pairs, boxes)):
+ batch_text_or_text_pair, boxes_example = example
+ outputs = self.prepare_for_model(
+ batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair,
+ batch_text_or_text_pair[1] if is_pair else None,
+ boxes_example,
+ word_labels=word_labels[idx] if word_labels is not None else None,
+ add_special_tokens=add_special_tokens,
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=None, # we pad in batch afterward
+ return_attention_mask=False, # we pad in batch afterward
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ return_tensors=None, # We convert the whole batch to tensors at the end
+ prepend_batch_axis=False,
+ verbose=verbose,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ batch_outputs = self.pad(
+ batch_outputs,
+ padding=padding_strategy.value,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+ return batch_outputs
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.encode
+ def encode(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> List[int]:
+ encoded_inputs = self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return encoded_inputs["input_ids"]
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.encode_plus
+ def encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated,
+ `__call__` should be used instead.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._encode_plus(
+ text=text,
+ boxes=boxes,
+ text_pair=text_pair,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._encode_plus
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ if return_offsets_mapping:
+ raise NotImplementedError(
+ "return_offset_mapping is not available when using Python tokenizers. "
+ "To use this feature, change your tokenizer to one deriving from "
+ "transformers.PreTrainedTokenizerFast. "
+ "More information on available tokenizers at "
+ "https://github.com/huggingface/transformers/pull/2674"
+ )
+
+ return self.prepare_for_model(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding_strategy.value,
+ truncation=truncation_strategy.value,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ prepend_batch_axis=True,
+ return_attention_mask=return_attention_mask,
+ return_token_type_ids=return_token_type_ids,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_length=return_length,
+ verbose=verbose,
+ )
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ def prepare_for_model(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ prepend_batch_axis: bool = False,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens,
+ truncates sequences if overflowing while taking into account the special tokens and manages a moving window
+ (with user defined stride) for overflowing tokens. Please Note, for *text_pair* different than `None` and
+ *truncation_strategy = longest_first* or `True`, it is not possible to return overflowing tokens. Such a
+ combination of arguments will raise an error.
+
+ Word-level `boxes` are turned into token-level `bbox`. If provided, word-level `word_labels` are turned into
+ token-level `labels`. The word label is used for the first token of the word, while remaining tokens are
+ labeled with -100, such that they will be ignored by the loss function.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ tokens = []
+ pair_tokens = []
+ token_boxes = []
+ pair_token_boxes = []
+ labels = []
+
+ if text_pair is None:
+ if word_labels is None:
+ # CASE 1: document image classification (training + inference) + CASE 2: token classification (inference)
+ for word, box in zip(text, boxes):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ tokens.extend(word_tokens)
+ token_boxes.extend([box] * len(word_tokens))
+ else:
+ # CASE 2: token classification (training)
+ for word, box, label in zip(text, boxes, word_labels):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ tokens.extend(word_tokens)
+ token_boxes.extend([box] * len(word_tokens))
+ if self.only_label_first_subword:
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1))
+ else:
+ labels.extend([label] * len(word_tokens))
+ else:
+ # CASE 3: document visual question answering (inference)
+ # text = question
+ # text_pair = words
+ tokens = self.tokenize(text)
+ token_boxes = [self.pad_token_box for _ in range(len(tokens))]
+
+ for word, box in zip(text_pair, boxes):
+ if len(word) < 1: # skip empty words
+ continue
+ word_tokens = self.tokenize(word)
+ pair_tokens.extend(word_tokens)
+ pair_token_boxes.extend([box] * len(word_tokens))
+
+ # Create ids + pair_ids
+ ids = self.convert_tokens_to_ids(tokens)
+ pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None
+
+ if (
+ return_overflowing_tokens
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+ and pair_ids is not None
+ ):
+ raise ValueError(
+ "Not possible to return overflowing tokens for pair of sequences with the "
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
+ "for instance `only_second` or `only_first`."
+ )
+
+ # Compute the total size of the returned encodings
+ pair = bool(pair_ids is not None)
+ len_ids = len(ids)
+ len_pair_ids = len(pair_ids) if pair else 0
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+ # Truncation: Handle max sequence length
+ overflowing_tokens = []
+ overflowing_token_boxes = []
+ overflowing_labels = []
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+ (
+ ids,
+ token_boxes,
+ pair_ids,
+ pair_token_boxes,
+ labels,
+ overflowing_tokens,
+ overflowing_token_boxes,
+ overflowing_labels,
+ ) = self.truncate_sequences(
+ ids,
+ token_boxes,
+ pair_ids=pair_ids,
+ pair_token_boxes=pair_token_boxes,
+ labels=labels,
+ num_tokens_to_remove=total_len - max_length,
+ truncation_strategy=truncation_strategy,
+ stride=stride,
+ )
+
+ if return_token_type_ids and not add_special_tokens:
+ raise ValueError(
+ "Asking to return token_type_ids while setting add_special_tokens to False "
+ "results in an undefined behavior. Please set add_special_tokens to True or "
+ "set return_token_type_ids to None."
+ )
+
+ # Load from model defaults
+ if return_token_type_ids is None:
+ return_token_type_ids = "token_type_ids" in self.model_input_names
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ encoded_inputs = {}
+
+ if return_overflowing_tokens:
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
+ encoded_inputs["overflowing_token_boxes"] = overflowing_token_boxes
+ encoded_inputs["overflowing_labels"] = overflowing_labels
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+ # Add special tokens
+ if add_special_tokens:
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+ token_boxes = [self.cls_token_box] + token_boxes + [self.sep_token_box]
+ if pair_token_boxes:
+ pair_token_boxes = [self.sep_token_box] + pair_token_boxes + [self.sep_token_box]
+ token_boxes = token_boxes + pair_token_boxes if pair else token_boxes
+ if labels:
+ labels = [self.pad_token_label] + labels + [self.pad_token_label]
+ else:
+ sequence = ids + pair_ids if pair else ids
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+ token_boxes = token_boxes + pair_token_boxes if pair else token_boxes
+
+ # Build output dictionary
+ encoded_inputs["input_ids"] = sequence
+ encoded_inputs["bbox"] = token_boxes
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = token_type_ids
+ if return_special_tokens_mask:
+ if add_special_tokens:
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+ else:
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+ if labels:
+ encoded_inputs["labels"] = labels
+
+ # Check lengths
+ self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+ # Padding
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+ encoded_inputs = self.pad(
+ encoded_inputs,
+ max_length=max_length,
+ padding=padding_strategy.value,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ if return_length:
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+ batch_outputs = BatchEncoding(
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+ )
+
+ return batch_outputs
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.truncate_sequences
+ def truncate_sequences(
+ self,
+ ids: List[int],
+ token_boxes: List[List[int]],
+ pair_ids: Optional[List[int]] = None,
+ pair_token_boxes: Optional[List[List[int]]] = None,
+ labels: Optional[List[int]] = None,
+ num_tokens_to_remove: int = 0,
+ truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+ stride: int = 0,
+ ) -> Tuple[List[int], List[int], List[int]]:
+ """
+ Truncates a sequence pair in-place following the strategy.
+
+ Args:
+ ids (`List[int]`):
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
+ `convert_tokens_to_ids` methods.
+ token_boxes (`List[List[int]]`):
+ Bounding boxes of the first sequence.
+ pair_ids (`List[int]`, *optional*):
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+ and `convert_tokens_to_ids` methods.
+ pair_token_boxes (`List[List[int]]`, *optional*):
+ Bounding boxes of the second sequence.
+ labels (`List[int]`, *optional*):
+ Labels of the first sequence (for token classification tasks).
+ num_tokens_to_remove (`int`, *optional*, defaults to 0):
+ Number of tokens to remove using the truncation strategy.
+ truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ The strategy to follow for truncation. Can be:
+
+ - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
+ batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater
+ than the model maximum admissible input size).
+ stride (`int`, *optional*, defaults to 0):
+ If set to a positive number, the overflowing tokens returned will contain some tokens from the main
+ sequence returned. The value of this argument defines the number of additional tokens.
+
+ Returns:
+ `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of
+ overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair
+ of sequences (or a batch of pairs) is provided.
+ """
+ if num_tokens_to_remove <= 0:
+ return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], []
+
+ if not isinstance(truncation_strategy, TruncationStrategy):
+ truncation_strategy = TruncationStrategy(truncation_strategy)
+
+ overflowing_tokens = []
+ overflowing_token_boxes = []
+ overflowing_labels = []
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
+ truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
+ ):
+ if len(ids) > num_tokens_to_remove:
+ window_len = min(len(ids), stride + num_tokens_to_remove)
+ overflowing_tokens = ids[-window_len:]
+ overflowing_token_boxes = token_boxes[-window_len:]
+ overflowing_labels = labels[-window_len:]
+ ids = ids[:-num_tokens_to_remove]
+ token_boxes = token_boxes[:-num_tokens_to_remove]
+ labels = labels[:-num_tokens_to_remove]
+ else:
+ error_msg = (
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
+ f"but the first sequence has a length {len(ids)}. "
+ )
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST:
+ error_msg = (
+ error_msg
+ + "Please select another truncation strategy than "
+ f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
+ )
+ logger.error(error_msg)
+ elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+ logger.warning(
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
+ f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
+ "truncation strategy. So the returned list will always be empty even if some "
+ "tokens have been removed."
+ )
+ for _ in range(num_tokens_to_remove):
+ if pair_ids is None or len(ids) > len(pair_ids):
+ ids = ids[:-1]
+ token_boxes = token_boxes[:-1]
+ labels = labels[:-1]
+ else:
+ pair_ids = pair_ids[:-1]
+ pair_token_boxes = pair_token_boxes[:-1]
+ elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+ if len(pair_ids) > num_tokens_to_remove:
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+ overflowing_tokens = pair_ids[-window_len:]
+ overflowing_token_boxes = pair_token_boxes[-window_len:]
+ pair_ids = pair_ids[:-num_tokens_to_remove]
+ pair_token_boxes = pair_token_boxes[:-num_tokens_to_remove]
+ else:
+ logger.error(
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
+ f"but the second sequence has a length {len(pair_ids)}. "
+ f"Please select another truncation strategy than {truncation_strategy}, "
+ "for instance 'longest_first' or 'only_first'."
+ )
+
+ return (
+ ids,
+ token_boxes,
+ pair_ids,
+ pair_token_boxes,
+ labels,
+ overflowing_tokens,
+ overflowing_token_boxes,
+ overflowing_labels,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._pad
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ >= 7.5 (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ if self.padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif self.padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+
+ return encoded_inputs
diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py
new file mode 100644
index 00000000000000..be5f938dbf17ce
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py
@@ -0,0 +1,853 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fast tokenization class for LayoutLMv3. It overwrites 2 methods of the slow tokenizer class, namely _batch_encode_plus
+and _encode_plus, in which the Rust tokenizer is used.
+"""
+
+import json
+from typing import Dict, List, Optional, Tuple, Union
+
+from tokenizers import pre_tokenizers, processors
+
+from ...tokenization_utils_base import (
+ BatchEncoding,
+ EncodedInput,
+ PaddingStrategy,
+ PreTokenizedInput,
+ TensorType,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import add_end_docstrings, logging
+from .tokenization_layoutlmv3 import (
+ LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING,
+ LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
+ LayoutLMv3Tokenizer,
+)
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/vocab.json",
+ "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/vocab.json",
+ },
+ "merges_file": {
+ "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/merges.txt",
+ "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/merges.txt",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "microsoft/layoutlmv3-base": 512,
+ "microsoft/layoutlmv3-large": 512,
+}
+
+
+class LayoutLMv3TokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" LayoutLMv3 tokenizer (backed by HuggingFace's *tokenizers* library). Based on BPE.
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `""`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+
+
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ sep_token (`str`, *optional*, defaults to `""`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `""`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `""`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+ other word. (RoBERTa tokenizer detect beginning of words by the preceding space).
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether the post processing step should trim offsets to avoid including whitespaces.
+ cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [CLS] token.
+ sep_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [SEP] token.
+ pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
+ The bounding box to use for the special [PAD] token.
+ pad_token_label (`int`, *optional*, defaults to -100):
+ The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
+ CrossEntropyLoss.
+ only_label_first_subword (`bool`, *optional*, defaults to `True`):
+ Whether or not to only label the first subword, in case word labels are provided.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = LayoutLMv3Tokenizer
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ errors="replace",
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ add_prefix_space=True,
+ trim_offsets=True,
+ cls_token_box=[0, 0, 0, 0],
+ sep_token_box=[0, 0, 0, 0],
+ pad_token_box=[0, 0, 0, 0],
+ pad_token_label=-100,
+ only_label_first_subword=True,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ errors=errors,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ add_prefix_space=add_prefix_space,
+ trim_offsets=trim_offsets,
+ cls_token_box=cls_token_box,
+ sep_token_box=sep_token_box,
+ pad_token_box=pad_token_box,
+ pad_token_label=pad_token_label,
+ only_label_first_subword=only_label_first_subword,
+ **kwargs,
+ )
+
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+
+ self.add_prefix_space = add_prefix_space
+
+ tokenizer_component = "post_processor"
+ tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
+ if tokenizer_component_instance:
+ state = json.loads(tokenizer_component_instance.__getstate__())
+
+ # The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class`
+ if "sep" in state:
+ state["sep"] = tuple(state["sep"])
+ if "cls" in state:
+ state["cls"] = tuple(state["cls"])
+
+ changes_to_apply = False
+
+ if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ state["add_prefix_space"] = add_prefix_space
+ changes_to_apply = True
+
+ if state.get("trim_offsets", trim_offsets) != trim_offsets:
+ state["trim_offsets"] = trim_offsets
+ changes_to_apply = True
+
+ if changes_to_apply:
+ component_class = getattr(processors, state.pop("type"))
+ new_value = component_class(**state)
+ setattr(self.backend_tokenizer, tokenizer_component, new_value)
+
+ # additional properties
+ self.cls_token_box = cls_token_box
+ self.sep_token_box = sep_token_box
+ self.pad_token_box = pad_token_box
+ self.pad_token_label = pad_token_label
+ self.only_label_first_subword = only_label_first_subword
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.__call__
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+ text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+ boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences with word-level normalized bounding boxes and optional labels.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
+ (words of a single example or questions of a batch of examples) or a list of list of strings (batch of
+ words).
+ text_pair (`List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
+ (pretokenized string).
+ boxes (`List[List[int]]`, `List[List[List[int]]]`):
+ Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
+ word_labels (`List[int]`, `List[List[int]]`, *optional*):
+ Word-level integer labels (for token classification tasks such as FUNSD, CORD).
+ """
+ # Input type checking for clearer error
+ def _is_valid_text_input(t):
+ if isinstance(t, str):
+ # Strings are fine
+ return True
+ elif isinstance(t, (list, tuple)):
+ # List are fine as long as they are...
+ if len(t) == 0:
+ # ... empty
+ return True
+ elif isinstance(t[0], str):
+ # ... list of strings
+ return True
+ elif isinstance(t[0], (list, tuple)):
+ # ... list with an empty list or with a list of strings
+ return len(t[0]) == 0 or isinstance(t[0][0], str)
+ else:
+ return False
+ else:
+ return False
+
+ if text_pair is not None:
+ # in case text + text_pair are provided, text = questions, text_pair = words
+ if not _is_valid_text_input(text):
+ raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ")
+ if not isinstance(text_pair, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+ else:
+ # in case only text is provided => must be words
+ if not isinstance(text, (list, tuple)):
+ raise ValueError(
+ "Words must be of type `List[str]` (single pretokenized example), "
+ "or `List[List[str]]` (batch of pretokenized examples)."
+ )
+
+ if text_pair is not None:
+ is_batched = isinstance(text, (list, tuple))
+ else:
+ is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+
+ words = text if text_pair is None else text_pair
+ if boxes is None:
+ raise ValueError("You must provide corresponding bounding boxes")
+ if is_batched:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide words and boxes for an equal amount of examples")
+ for words_example, boxes_example in zip(words, boxes):
+ if len(words_example) != len(boxes_example):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+ else:
+ if len(words) != len(boxes):
+ raise ValueError("You must provide as many words as there are bounding boxes")
+
+ if is_batched:
+ if text_pair is not None and len(text) != len(text_pair):
+ raise ValueError(
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
+ )
+ batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+ is_pair = bool(text_pair is not None)
+ return self.batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+ else:
+ return self.encode_plus(
+ text=text,
+ text_pair=text_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.batch_encode_plus
+ def batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ ],
+ is_pair: bool = None,
+ boxes: Optional[List[List[List[int]]]] = None,
+ word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._batch_encode_plus(
+ batch_text_or_text_pairs=batch_text_or_text_pairs,
+ is_pair=is_pair,
+ boxes=boxes,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.tokenize
+ def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
+ batched_input = [(text, pair)] if pair else [text]
+ encodings = self._tokenizer.encode_batch(
+ batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
+ )
+
+ return encodings[0].tokens
+
+ @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.encode_plus
+ def encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated,
+ `__call__` should be used instead.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
+ text_pair (`List[str]` or `List[int]`, *optional*):
+ Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
+ list of list of strings (words of a batch of examples).
+ """
+
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ return self._encode_plus(
+ text=text,
+ boxes=boxes,
+ text_pair=text_pair,
+ word_labels=word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._batch_encode_plus with LayoutLMv2->LayoutLMv3
+ def _batch_encode_plus(
+ self,
+ batch_text_or_text_pairs: Union[
+ List[TextInput],
+ List[TextInputPair],
+ List[PreTokenizedInput],
+ ],
+ is_pair: bool = None,
+ boxes: Optional[List[List[List[int]]]] = None,
+ word_labels: Optional[List[List[int]]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[str] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ ) -> BatchEncoding:
+
+ if not isinstance(batch_text_or_text_pairs, list):
+ raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
+
+ # Set the truncation and padding strategy and restore the initial configuration
+ self.set_truncation_and_padding(
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ )
+
+ if is_pair:
+ batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs]
+
+ encodings = self._tokenizer.encode_batch(
+ batch_text_or_text_pairs,
+ add_special_tokens=add_special_tokens,
+ is_pretokenized=True, # we set this to True as LayoutLMv3 always expects pretokenized inputs
+ )
+
+ # Convert encoding to dict
+ # `Tokens` has type: Tuple[
+ # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
+ # List[EncodingFast]
+ # ]
+ # with nested dimensions corresponding to batch, overflows, sequence length
+ tokens_and_encodings = [
+ self._convert_encoding(
+ encoding=encoding,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=True
+ if word_labels is not None
+ else return_offsets_mapping, # we use offsets to create the labels
+ return_length=return_length,
+ verbose=verbose,
+ )
+ for encoding in encodings
+ ]
+
+ # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
+ # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+ # (we say ~ because the number of overflow varies with the example in the batch)
+ #
+ # To match each overflowing sample with the original sample in the batch
+ # we add an overflow_to_sample_mapping array (see below)
+ sanitized_tokens = {}
+ for key in tokens_and_encodings[0][0].keys():
+ stack = [e for item, _ in tokens_and_encodings for e in item[key]]
+ sanitized_tokens[key] = stack
+ sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
+
+ # If returning overflowing tokens, we need to return a mapping
+ # from the batch idx to the original sample
+ if return_overflowing_tokens:
+ overflow_to_sample_mapping = []
+ for i, (toks, _) in enumerate(tokens_and_encodings):
+ overflow_to_sample_mapping += [i] * len(toks["input_ids"])
+ sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+
+ for input_ids in sanitized_tokens["input_ids"]:
+ self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+
+ # create the token boxes
+ token_boxes = []
+ for batch_index in range(len(sanitized_tokens["input_ids"])):
+ if return_overflowing_tokens:
+ original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
+ else:
+ original_index = batch_index
+ token_boxes_example = []
+ for id, sequence_id, word_id in zip(
+ sanitized_tokens["input_ids"][batch_index],
+ sanitized_encodings[batch_index].sequence_ids,
+ sanitized_encodings[batch_index].word_ids,
+ ):
+ if word_id is not None:
+ if is_pair and sequence_id == 0:
+ token_boxes_example.append(self.pad_token_box)
+ else:
+ token_boxes_example.append(boxes[original_index][word_id])
+ else:
+ if id == self.cls_token_id:
+ token_boxes_example.append(self.cls_token_box)
+ elif id == self.sep_token_id:
+ token_boxes_example.append(self.sep_token_box)
+ elif id == self.pad_token_id:
+ token_boxes_example.append(self.pad_token_box)
+ else:
+ raise ValueError("Id not recognized")
+ token_boxes.append(token_boxes_example)
+
+ sanitized_tokens["bbox"] = token_boxes
+
+ # optionally, create the labels
+ if word_labels is not None:
+ labels = []
+ for batch_index in range(len(sanitized_tokens["input_ids"])):
+ if return_overflowing_tokens:
+ original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
+ else:
+ original_index = batch_index
+ labels_example = []
+ for id, offset, word_id in zip(
+ sanitized_tokens["input_ids"][batch_index],
+ sanitized_tokens["offset_mapping"][batch_index],
+ sanitized_encodings[batch_index].word_ids,
+ ):
+ if word_id is not None:
+ if self.only_label_first_subword:
+ if offset[0] == 0:
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ labels_example.append(word_labels[original_index][word_id])
+ else:
+ labels_example.append(self.pad_token_label)
+ else:
+ labels_example.append(word_labels[original_index][word_id])
+ else:
+ labels_example.append(self.pad_token_label)
+ labels.append(labels_example)
+
+ sanitized_tokens["labels"] = labels
+ # finally, remove offsets if the user didn't want them
+ if not return_offsets_mapping:
+ del sanitized_tokens["offset_mapping"]
+
+ return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._encode_plus
+ def _encode_plus(
+ self,
+ text: Union[TextInput, PreTokenizedInput],
+ text_pair: Optional[PreTokenizedInput] = None,
+ boxes: Optional[List[List[int]]] = None,
+ word_labels: Optional[List[int]] = None,
+ add_special_tokens: bool = True,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_tensors: Optional[bool] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ **kwargs
+ ) -> BatchEncoding:
+
+ # make it a batched input
+ # 2 options:
+ # 1) only text, in case text must be a list of str
+ # 2) text + text_pair, in which case text = str and text_pair a list of str
+ batched_input = [(text, text_pair)] if text_pair else [text]
+ batched_boxes = [boxes]
+ batched_word_labels = [word_labels] if word_labels is not None else None
+ batched_output = self._batch_encode_plus(
+ batched_input,
+ is_pair=bool(text_pair is not None),
+ boxes=batched_boxes,
+ word_labels=batched_word_labels,
+ add_special_tokens=add_special_tokens,
+ padding_strategy=padding_strategy,
+ truncation_strategy=truncation_strategy,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors=return_tensors,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ **kwargs,
+ )
+
+ # Return tensor is None, then we can remove the leading batch axis
+ # Overflowing tokens are returned as a batch of output so we keep them in this case
+ if return_tensors is None and not return_overflowing_tokens:
+ batched_output = BatchEncoding(
+ {
+ key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
+ for key, value in batched_output.items()
+ },
+ batched_output.encodings,
+ )
+
+ self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+ return batched_output
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._pad
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ >= 7.5 (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ if return_attention_mask is None:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ if self.padding_side == "right":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+ elif self.padding_side == "left":
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "bbox" in encoded_inputs:
+ encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
+ if "labels" in encoded_inputs:
+ encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+ else:
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+
+ return encoded_inputs
+
+ # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+ if token_ids_1 is None:
+ return output
+
+ return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Args:
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not:
+ make use of token type ids, therefore a list of zeros is returned.
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
diff --git a/src/transformers/models/layoutxlm/__init__.py b/src/transformers/models/layoutxlm/__init__.py
index c9459aff203389..9c09d75d68ba25 100644
--- a/src/transformers/models/layoutxlm/__init__.py
+++ b/src/transformers/models/layoutxlm/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
@@ -27,27 +28,43 @@
)
-_import_structure = {}
+_import_structure = {"processing_layoutxlm": ["LayoutXLMProcessor"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_layoutxlm"] = ["LayoutXLMTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_layoutxlm_fast"] = ["LayoutXLMTokenizerFast"]
-if is_vision_available():
- _import_structure["processing_layoutxlm"] = ["LayoutXLMProcessor"]
-
if TYPE_CHECKING:
- if is_sentencepiece_available():
+ from .processing_layoutxlm import LayoutXLMProcessor
+
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_layoutxlm import LayoutXLMTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_layoutxlm_fast import LayoutXLMTokenizerFast
- if is_vision_available():
- from .processing_layoutlmv2 import LayoutXLMProcessor
-
else:
import sys
diff --git a/src/transformers/models/layoutxlm/processing_layoutxlm.py b/src/transformers/models/layoutxlm/processing_layoutxlm.py
index 99245ccc177696..03423d17c27be2 100644
--- a/src/transformers/models/layoutxlm/processing_layoutxlm.py
+++ b/src/transformers/models/layoutxlm/processing_layoutxlm.py
@@ -86,8 +86,7 @@ def __call__(
if self.feature_extractor.apply_ocr and (word_labels is not None):
raise ValueError(
- "You cannot provide word labels "
- "if you initialized the feature extractor with apply_ocr set to True."
+ "You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
)
# first, apply the feature extractor
@@ -122,6 +121,37 @@ def __call__(
)
# add pixel values
- encoded_inputs["image"] = features.pop("pixel_values")
+ images = features.pop("pixel_values")
+ if return_overflowing_tokens is True:
+ images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
+ encoded_inputs["image"] = images
return encoded_inputs
+
+ def get_overflowing_images(self, images, overflow_to_sample_mapping):
+ # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
+ images_with_overflow = []
+ for sample_idx in overflow_to_sample_mapping:
+ images_with_overflow.append(images[sample_idx])
+
+ if len(images_with_overflow) != len(overflow_to_sample_mapping):
+ raise ValueError(
+ "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
+ f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
+ )
+
+ return images_with_overflow
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py
index 8fded392844daa..c0c9acfe476fa0 100644
--- a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py
+++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py
@@ -438,7 +438,8 @@ def _is_valid_text_input(t):
if is_batched:
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
is_pair = bool(text_pair is not None)
@@ -960,7 +961,7 @@ def truncate_sequences(
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the first sequence has a length {len(ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
- f"for instance 'longest_first' or 'only_second'."
+ "for instance 'longest_first' or 'only_second'."
)
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
@@ -974,7 +975,7 @@ def truncate_sequences(
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
- f"for instance 'longest_first' or 'only_first'."
+ "for instance 'longest_first' or 'only_first'."
)
return (
diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py
index 35b438387747b2..1477d06b802bd7 100644
--- a/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py
+++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py
@@ -265,7 +265,8 @@ def _is_valid_text_input(t):
if is_batched:
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
is_pair = bool(text_pair is not None)
diff --git a/src/transformers/models/led/__init__.py b/src/transformers/models/led/__init__.py
index d60800f981a5f7..da871828ad8889 100644
--- a/src/transformers/models/led/__init__.py
+++ b/src/transformers/models/led/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -25,10 +31,20 @@
"tokenization_led": ["LEDTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_led_fast"] = ["LEDTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_led"] = [
"LED_PRETRAINED_MODEL_ARCHIVE_LIST",
"LEDForConditionalGeneration",
@@ -39,7 +55,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_led"] = ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"]
@@ -47,10 +68,20 @@
from .configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from .tokenization_led import LEDTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_led_fast import LEDTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
LEDForConditionalGeneration,
@@ -60,7 +91,12 @@
LEDPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel
else:
diff --git a/src/transformers/models/led/configuration_led.py b/src/transformers/models/led/configuration_led.py
index 5f534ab28703f1..37720c730af1e2 100644
--- a/src/transformers/models/led/configuration_led.py
+++ b/src/transformers/models/led/configuration_led.py
@@ -86,18 +86,17 @@ class LEDConfig(PretrainedConfig):
Example:
```python
+ >>> from transformers import LEDModel, LEDConfig
- ```
+ >>> # Initializing a LED allenai/led-base-16384 style configuration
+ >>> configuration = LEDConfig()
- >>> from transformers import LEDModel, LEDConfig
+ >>> # Initializing a model from the allenai/led-base-16384 style configuration
+ >>> model = LEDModel(configuration)
- >>> # Initializing a LED allenai/led-base-16384 style configuration >>> configuration = LEDConfig()
-
- >>> # Initializing a model from the allenai/led-base-16384 style configuration >>> model =
- LEDModel(configuration)
-
- >>> # Accessing the model configuration >>> configuration = model.config
- """
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
model_type = "led"
attribute_map = {
"num_attention_heads": "encoder_attention_heads",
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 3e852cf2a67d55..6ac0bfccb19661 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -222,7 +222,10 @@ def forward(
seq_len,
self.num_heads,
self.one_sided_attn_window_size * 2 + 1,
- ], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ ], (
+ f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ )
# compute local attention probs from global attention keys and contact over window dim
if is_global_attn:
@@ -662,7 +665,11 @@ def _compute_global_attn_output_from_hidden(
batch_size * self.num_heads,
max_num_global_attn_indices,
seq_len,
- ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}."
+ ], (
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {global_attn_scores.size()}."
+ )
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
@@ -705,7 +712,11 @@ def _compute_global_attn_output_from_hidden(
batch_size * self.num_heads,
max_num_global_attn_indices,
self.head_dim,
- ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}."
+ ], (
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {global_attn_output.size()}."
+ )
global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_output = global_attn_output.view(
@@ -766,7 +777,8 @@ def __init__(
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
@@ -837,7 +849,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -852,7 +865,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -873,7 +887,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = (
@@ -1007,7 +1022,7 @@ def forward(
"""
residual = hidden_states
- # Self Attention
+ # Self-Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
@@ -1437,13 +1452,11 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
LED_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
+ This model inherits from [`PreTrainedModel`]. See the superclass documentation for the generic methods the library
+ implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for general usage and behavior.
Parameters:
config ([`LEDConfig`]):
@@ -1595,7 +1608,7 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
class LEDEncoder(LEDPreTrainedModel):
"""
- Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
[`LEDEncoderLayer`].
Args:
@@ -1643,7 +1656,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non
self.post_init()
def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
- # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
+ # longformer self-attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
# => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
if attention_mask is not None:
@@ -1815,7 +1828,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -2071,7 +2085,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -2283,9 +2298,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
base_model_prefix = "led"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
]
def __init__(self, config: LEDConfig):
@@ -2426,6 +2441,7 @@ def prepare_inputs_for_generation(
decoder_input_ids,
past=None,
attention_mask=None,
+ global_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@@ -2443,6 +2459,7 @@ def prepare_inputs_for_generation(
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
+ "global_attention_mask": global_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py
index a882e32ec4e7eb..83a71a0dfe8ad2 100644
--- a/src/transformers/models/led/modeling_tf_led.py
+++ b/src/transformers/models/led/modeling_tf_led.py
@@ -246,7 +246,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
- message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
+ message=(
+ f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
+ ),
)
# compute global attn indices required through out forward fn
@@ -299,7 +302,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@@ -392,7 +398,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
- message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}",
+ message=(
+ f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
+ f" {shape_list(key)}"
+ ),
)
chunks_count = seq_len // window_overlap - 1
@@ -677,7 +686,10 @@ def _chunk(hidden_states, window_overlap):
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
- message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
+ message=(
+ "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
+ f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
+ ),
)
chunked_hidden_states = tf.reshape(
@@ -855,7 +867,11 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
- message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
+ message=(
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {shape_list(global_attn_scores)}."
+ ),
)
global_attn_scores = tf.reshape(
@@ -894,7 +910,10 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
@@ -913,7 +932,11 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
- message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.",
+ message=(
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {shape_list(global_attn_output)}."
+ ),
)
global_attn_output = tf.reshape(
@@ -1069,7 +1092,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -1077,7 +1103,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
@@ -1092,7 +1121,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -1108,7 +1140,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -1238,7 +1273,7 @@ def call(
"""
residual = hidden_states
- # Self Attention
+ # Self-Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
@@ -1612,7 +1647,7 @@ class TFLEDSeq2SeqLMOutput(ModelOutput):
class TFLEDEncoder(tf.keras.layers.Layer):
config_class = LEDConfig
"""
- Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
[`TFLEDEncoderLayer`].
Args:
@@ -1753,7 +1788,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
# encoder layers
@@ -2013,7 +2051,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
diff --git a/src/transformers/models/levit/__init__.py b/src/transformers/models/levit/__init__.py
new file mode 100644
index 00000000000000..bdbcaed41a15bd
--- /dev/null
+++ b/src/transformers/models/levit/__init__.py
@@ -0,0 +1,75 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"]}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_levit"] = ["LevitFeatureExtractor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_levit"] = [
+ "LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LevitForImageClassification",
+ "LevitForImageClassificationWithTeacher",
+ "LevitModel",
+ "LevitPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_levit import LevitFeatureExtractor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_levit import (
+ LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LevitForImageClassification,
+ LevitForImageClassificationWithTeacher,
+ LevitModel,
+ LevitPreTrainedModel,
+ )
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/levit/configuration_levit.py b/src/transformers/models/levit/configuration_levit.py
new file mode 100644
index 00000000000000..5d75b9fc23e759
--- /dev/null
+++ b/src/transformers/models/levit/configuration_levit.py
@@ -0,0 +1,122 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" LeViT model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/levit-128S": "https://huggingface.co/facebook/levit-128S/resolve/main/config.json",
+ # See all LeViT models at https://huggingface.co/models?filter=levit
+}
+
+
+class LevitConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LevitModel`]. It is used to instantiate a LeViT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the LeViT
+ [facebook/levit-base-192](https://huggingface.co/facebook/levit-base-192) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to 224):
+ The size of the input image.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input image.
+ kernel_size (`int`, *optional*, defaults to 3):
+ The kernel size for the initial convolution layers of patch embedding.
+ stride (`int`, *optional*, defaults to 2):
+ The stride size for the initial convolution layers of patch embedding.
+ padding (`int`, *optional*, defaults to 1):
+ The padding size for the initial convolution layers of patch embedding.
+ patch_size (`int`, *optional*, defaults to 16):
+ The patch size for embeddings.
+ hidden_sizes (`List[int]`, *optional*, defaults to `[128, 256, 384]`):
+ Dimension of each of the encoder blocks.
+ num_attention_heads (`List[int]`, *optional*, defaults to `[4, 8, 12]`):
+ Number of attention heads for each attention layer in each block of the Transformer encoder.
+ depths (`List[int]`, *optional*, defaults to `[4, 4, 4]`):
+ The number of layers in each encoder block.
+ key_dim (`List[int]`, *optional*, defaults to `[16, 16, 16]`):
+ The size of key in each of the encoder blocks.
+ drop_path_rate (`int`, *optional*, defaults to 0):
+ The dropout probability for stochastic depths, used in the blocks of the Transformer encoder.
+ mlp_ratios (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
+ Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
+ encoder blocks.
+ attention_ratios (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
+ Ratio of the size of the output dimension compared to input dimension of attention layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+ Example:
+
+ ```python
+ >>> from transformers import LevitModel, LevitConfig
+
+ >>> # Initializing a LeViT levit-base-192 style configuration
+ >>> configuration = LevitConfig()
+
+ >>> # Initializing a model from the levit-base-192 style configuration
+ >>> model = LevitModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "levit"
+
+ def __init__(
+ self,
+ image_size=224,
+ num_channels=3,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ patch_size=16,
+ hidden_sizes=[128, 256, 384],
+ num_attention_heads=[4, 8, 12],
+ depths=[4, 4, 4],
+ key_dim=[16, 16, 16],
+ drop_path_rate=0,
+ mlp_ratio=[2, 2, 2],
+ attention_ratio=[2, 2, 2],
+ initializer_range=0.02,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.hidden_sizes = hidden_sizes
+ self.num_attention_heads = num_attention_heads
+ self.depths = depths
+ self.key_dim = key_dim
+ self.drop_path_rate = drop_path_rate
+ self.patch_size = patch_size
+ self.attention_ratio = attention_ratio
+ self.mlp_ratio = mlp_ratio
+ self.initializer_range = initializer_range
+ self.down_ops = [
+ ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
+ ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
+ ]
diff --git a/src/transformers/models/levit/convert_levit_timm_to_pytorch.py b/src/transformers/models/levit/convert_levit_timm_to_pytorch.py
new file mode 100644
index 00000000000000..d9449aad7ab1d9
--- /dev/null
+++ b/src/transformers/models/levit/convert_levit_timm_to_pytorch.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert LeViT checkpoints from timm."""
+
+
+import argparse
+import json
+from collections import OrderedDict
+from functools import partial
+from pathlib import Path
+
+import torch
+
+import timm
+from huggingface_hub import hf_hub_download
+from transformers import LevitConfig, LevitFeatureExtractor, LevitForImageClassificationWithTeacher
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger()
+
+
+def convert_weight_and_push(
+ hidden_sizes: int, name: str, config: LevitConfig, save_directory: Path, push_to_hub: bool = True
+):
+ print(f"Converting {name}...")
+
+ with torch.no_grad():
+ if hidden_sizes == 128:
+ if name[-1] == "S":
+ from_model = timm.create_model("levit_128s", pretrained=True)
+ else:
+ from_model = timm.create_model("levit_128", pretrained=True)
+ if hidden_sizes == 192:
+ from_model = timm.create_model("levit_192", pretrained=True)
+ if hidden_sizes == 256:
+ from_model = timm.create_model("levit_256", pretrained=True)
+ if hidden_sizes == 384:
+ from_model = timm.create_model("levit_384", pretrained=True)
+
+ from_model.eval()
+ our_model = LevitForImageClassificationWithTeacher(config).eval()
+ huggingface_weights = OrderedDict()
+
+ weights = from_model.state_dict()
+ og_keys = list(from_model.state_dict().keys())
+ new_keys = list(our_model.state_dict().keys())
+ print(len(og_keys), len(new_keys))
+ for i in range(len(og_keys)):
+ huggingface_weights[new_keys[i]] = weights[og_keys[i]]
+ our_model.load_state_dict(huggingface_weights)
+
+ x = torch.randn((2, 3, 224, 224))
+ out1 = from_model(x)
+ out2 = our_model(x).logits
+
+ assert torch.allclose(out1, out2), "The model logits don't match the original one."
+
+ checkpoint_name = name
+ print(checkpoint_name)
+
+ if push_to_hub:
+ our_model.save_pretrained(save_directory / checkpoint_name)
+ feature_extractor = LevitFeatureExtractor()
+ feature_extractor.save_pretrained(save_directory / checkpoint_name)
+
+ print(f"Pushed {checkpoint_name}")
+
+
+def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
+ filename = "imagenet-1k-id2label.json"
+ num_labels = 1000
+ expected_shape = (1, num_labels)
+
+ repo_id = "datasets/huggingface/label-files"
+ num_labels = num_labels
+ id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+
+ id2label = id2label
+ label2id = {v: k for k, v in id2label.items()}
+
+ ImageNetPreTrainedConfig = partial(LevitConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
+
+ names_to_hidden_sizes = {
+ "levit-128S": 128,
+ "levit-128": 128,
+ "levit-192": 192,
+ "levit-256": 256,
+ "levit-384": 384,
+ }
+
+ names_to_config = {
+ "levit-128S": ImageNetPreTrainedConfig(
+ hidden_sizes=[128, 256, 384],
+ num_attention_heads=[4, 6, 8],
+ depths=[2, 3, 4],
+ key_dim=[16, 16, 16],
+ drop_path_rate=0,
+ ),
+ "levit-128": ImageNetPreTrainedConfig(
+ hidden_sizes=[128, 256, 384],
+ num_attention_heads=[4, 8, 12],
+ depths=[4, 4, 4],
+ key_dim=[16, 16, 16],
+ drop_path_rate=0,
+ ),
+ "levit-192": ImageNetPreTrainedConfig(
+ hidden_sizes=[192, 288, 384],
+ num_attention_heads=[3, 5, 6],
+ depths=[4, 4, 4],
+ key_dim=[32, 32, 32],
+ drop_path_rate=0,
+ ),
+ "levit-256": ImageNetPreTrainedConfig(
+ hidden_sizes=[256, 384, 512],
+ num_attention_heads=[4, 6, 8],
+ depths=[4, 4, 4],
+ key_dim=[32, 32, 32],
+ drop_path_rate=0,
+ ),
+ "levit-384": ImageNetPreTrainedConfig(
+ hidden_sizes=[384, 512, 768],
+ num_attention_heads=[6, 9, 12],
+ depths=[4, 4, 4],
+ key_dim=[32, 32, 32],
+ drop_path_rate=0.1,
+ ),
+ }
+
+ if model_name:
+ convert_weight_and_push(
+ names_to_hidden_sizes[model_name], model_name, names_to_config[model_name], save_directory, push_to_hub
+ )
+ else:
+ for model_name, config in names_to_config.items():
+ convert_weight_and_push(names_to_hidden_sizes[model_name], model_name, config, save_directory, push_to_hub)
+ return config, expected_shape
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default=None,
+ type=str,
+ help="The name of the model you wish to convert, it must be one of the supported Levit* architecture,",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default="levit-dump-folder/",
+ type=Path,
+ required=False,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ default=True,
+ type=bool,
+ required=False,
+ help="If True, push model and feature extractor to the hub.",
+ )
+
+ args = parser.parse_args()
+ pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
+ pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
+ convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
diff --git a/src/transformers/models/levit/feature_extraction_levit.py b/src/transformers/models/levit/feature_extraction_levit.py
new file mode 100644
index 00000000000000..b0ac5f6b3d3030
--- /dev/null
+++ b/src/transformers/models/levit/feature_extraction_levit.py
@@ -0,0 +1,158 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for LeViT."""
+
+from typing import Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ImageFeatureExtractionMixin,
+ ImageInput,
+ is_torch_tensor,
+)
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LevitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a LeViT feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the shortest edge of the input to int(256/224 *`size`).
+ size (`int` or `Tuple(int)`, *optional*, defaults to 224):
+ Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
+ integer is provided, then shorter side of input will be resized to 'size'.
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether or not to center crop the input to `size`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
+ The sequence of means for each channel, to be used when normalizing images.
+ image_std (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=224,
+ resample=Image.BICUBIC,
+ do_center_crop=True,
+ do_normalize=True,
+ image_mean=IMAGENET_DEFAULT_MEAN,
+ image_std=IMAGENET_DEFAULT_STD,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def __call__(
+ self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
+ PIL images.
+
+
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+ """
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+ "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (resizing + center cropping + normalization)
+ if self.do_resize and self.size is not None:
+ size_ = int((256 / 224) * self.size)
+ images = [
+ self.resize(image=image, size=size_, resample=self.resample, default_to_square=False)
+ for image in images
+ ]
+ if self.do_center_crop:
+ images = [self.center_crop(image=image, size=self.size) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+
+ # return as BatchFeature
+ data = {"pixel_values": images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py
new file mode 100644
index 00000000000000..b04a98317d7899
--- /dev/null
+++ b/src/transformers/models/levit/modeling_levit.py
@@ -0,0 +1,738 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LeViT model."""
+
+import itertools
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...modeling_outputs import (
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+ ModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_levit import LevitConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "LevitConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "LevitFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/levit-128S"
+_EXPECTED_OUTPUT_SHAPE = [1, 16, 384]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/levit-128S"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/levit-128S",
+ # See all LeViT models at https://huggingface.co/models?filter=levit
+]
+
+
+@dataclass
+class LevitForImageClassificationWithTeacherOutput(ModelOutput):
+ """
+ Output type of [`LevitForImageClassificationWithTeacher`].
+
+ Args:
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores as the average of the `cls_logits` and `distillation_logits`.
+ cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+ class token).
+ distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+ distillation token).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ """
+
+ logits: torch.FloatTensor = None
+ cls_logits: torch.FloatTensor = None
+ distillation_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class LevitConvEmbeddings(nn.Module):
+ """
+ LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.
+ """
+
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1
+ ):
+ super().__init__()
+ self.convolution = nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False
+ )
+ self.batch_norm = nn.BatchNorm2d(out_channels)
+
+ def forward(self, embeddings):
+ embeddings = self.convolution(embeddings)
+ embeddings = self.batch_norm(embeddings)
+ return embeddings
+
+
+class LevitPatchEmbeddings(nn.Module):
+ """
+ LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
+ `LevitConvEmbeddings`.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.embedding_layer_1 = LevitConvEmbeddings(
+ config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding
+ )
+ self.activation_layer_1 = nn.Hardswish()
+
+ self.embedding_layer_2 = LevitConvEmbeddings(
+ config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding
+ )
+ self.activation_layer_2 = nn.Hardswish()
+
+ self.embedding_layer_3 = LevitConvEmbeddings(
+ config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding
+ )
+ self.activation_layer_3 = nn.Hardswish()
+
+ self.embedding_layer_4 = LevitConvEmbeddings(
+ config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
+ )
+
+ def forward(self, pixel_values):
+ embeddings = self.embedding_layer_1(pixel_values)
+ embeddings = self.activation_layer_1(embeddings)
+ embeddings = self.embedding_layer_2(embeddings)
+ embeddings = self.activation_layer_2(embeddings)
+ embeddings = self.embedding_layer_3(embeddings)
+ embeddings = self.activation_layer_3(embeddings)
+ embeddings = self.embedding_layer_4(embeddings)
+ return embeddings.flatten(2).transpose(1, 2)
+
+
+class MLPLayerWithBN(nn.Module):
+ def __init__(self, input_dim, output_dim, bn_weight_init=1):
+ super().__init__()
+ self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
+ self.batch_norm = nn.BatchNorm1d(output_dim)
+
+ def forward(self, hidden_state):
+ hidden_state = self.linear(hidden_state)
+ hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state)
+ return hidden_state
+
+
+class LevitSubsample(nn.Module):
+ def __init__(self, stride, resolution):
+ super().__init__()
+ self.stride = stride
+ self.resolution = resolution
+
+ def forward(self, hidden_state):
+ batch_size, _, channels = hidden_state.shape
+ hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[
+ :, :: self.stride, :: self.stride
+ ].reshape(batch_size, -1, channels)
+ return hidden_state
+
+
+class LevitAttention(nn.Module):
+ def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.attention_ratio = attention_ratio
+ self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
+ self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
+
+ self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values)
+ self.activation = nn.Hardswish()
+ self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0)
+
+ points = list(itertools.product(range(resolution), range(resolution)))
+ len_points = len(points)
+ attention_offsets, indices = {}, []
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ indices.append(attention_offsets[offset])
+
+ self.attention_bias_cache = {}
+ self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
+ self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ super().train(mode)
+ if mode and self.attention_bias_cache:
+ self.attention_bias_cache = {} # clear ab cache
+
+ def get_attention_biases(self, device):
+ if self.training:
+ return self.attention_biases[:, self.attention_bias_idxs]
+ else:
+ device_key = str(device)
+ if device_key not in self.attention_bias_cache:
+ self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+ return self.attention_bias_cache[device_key]
+
+ def forward(self, hidden_state):
+ batch_size, seq_length, _ = hidden_state.shape
+ queries_keys_values = self.queries_keys_values(hidden_state)
+ query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split(
+ [self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3
+ )
+ query = query.permute(0, 2, 1, 3)
+ key = key.permute(0, 2, 1, 3)
+ value = value.permute(0, 2, 1, 3)
+
+ attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
+ attention = attention.softmax(dim=-1)
+ hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
+ hidden_state = self.projection(self.activation(hidden_state))
+ return hidden_state
+
+
+class LevitAttentionSubsample(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ output_dim,
+ key_dim,
+ num_attention_heads,
+ attention_ratio,
+ stride,
+ resolution_in,
+ resolution_out,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.attention_ratio = attention_ratio
+ self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
+ self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
+ self.resolution_out = resolution_out
+ # resolution_in is the intial resolution, resoloution_out is final resolution after downsampling
+ self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values)
+ self.queries_subsample = LevitSubsample(stride, resolution_in)
+ self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads)
+ self.activation = nn.Hardswish()
+ self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim)
+
+ self.attention_bias_cache = {}
+
+ points = list(itertools.product(range(resolution_in), range(resolution_in)))
+ points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
+ len_points, len_points_ = len(points), len(points_)
+ attention_offsets, indices = {}, []
+ for p1 in points_:
+ for p2 in points:
+ size = 1
+ offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ indices.append(attention_offsets[offset])
+
+ self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
+ self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ super().train(mode)
+ if mode and self.attention_bias_cache:
+ self.attention_bias_cache = {} # clear ab cache
+
+ def get_attention_biases(self, device):
+ if self.training:
+ return self.attention_biases[:, self.attention_bias_idxs]
+ else:
+ device_key = str(device)
+ if device_key not in self.attention_bias_cache:
+ self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+ return self.attention_bias_cache[device_key]
+
+ def forward(self, hidden_state):
+ batch_size, seq_length, _ = hidden_state.shape
+ key, value = (
+ self.keys_values(hidden_state)
+ .view(batch_size, seq_length, self.num_attention_heads, -1)
+ .split([self.key_dim, self.attention_ratio * self.key_dim], dim=3)
+ )
+ key = key.permute(0, 2, 1, 3)
+ value = value.permute(0, 2, 1, 3)
+
+ query = self.queries(self.queries_subsample(hidden_state))
+ query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute(
+ 0, 2, 1, 3
+ )
+
+ attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
+ attention = attention.softmax(dim=-1)
+ hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
+ hidden_state = self.projection(self.activation(hidden_state))
+ return hidden_state
+
+
+class LevitMLPLayer(nn.Module):
+ """
+ MLP Layer with `2X` expansion in contrast to ViT with `4X`.
+ """
+
+ def __init__(self, input_dim, hidden_dim):
+ super().__init__()
+ self.linear_up = MLPLayerWithBN(input_dim, hidden_dim)
+ self.activation = nn.Hardswish()
+ self.linear_down = MLPLayerWithBN(hidden_dim, input_dim)
+
+ def forward(self, hidden_state):
+ hidden_state = self.linear_up(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.linear_down(hidden_state)
+ return hidden_state
+
+
+class LevitResidualLayer(nn.Module):
+ """
+ Residual Block for LeViT
+ """
+
+ def __init__(self, module, drop_rate):
+ super().__init__()
+ self.module = module
+ self.drop_rate = drop_rate
+
+ def forward(self, hidden_state):
+ if self.training and self.drop_rate > 0:
+ rnd = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device)
+ rnd = rnd.ge_(self.drop_rate).div(1 - self.drop_rate).detach()
+ hidden_state = hidden_state + self.module(hidden_state) * rnd
+ return hidden_state
+ else:
+ hidden_state = hidden_state + self.module(hidden_state)
+ return hidden_state
+
+
+class LevitStage(nn.Module):
+ """
+ LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.
+ """
+
+ def __init__(
+ self,
+ config,
+ idx,
+ hidden_sizes,
+ key_dim,
+ depths,
+ num_attention_heads,
+ attention_ratio,
+ mlp_ratio,
+ down_ops,
+ resolution_in,
+ ):
+ super().__init__()
+ self.layers = []
+ self.config = config
+ self.resolution_in = resolution_in
+ # resolution_in is the intial resolution, resolution_out is final resolution after downsampling
+ for _ in range(depths):
+ self.layers.append(
+ LevitResidualLayer(
+ LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
+ self.config.drop_path_rate,
+ )
+ )
+ if mlp_ratio > 0:
+ hidden_dim = hidden_sizes * mlp_ratio
+ self.layers.append(
+ LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), self.config.drop_path_rate)
+ )
+
+ if down_ops[0] == "Subsample":
+ self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
+ self.layers.append(
+ LevitAttentionSubsample(
+ *self.config.hidden_sizes[idx : idx + 2],
+ key_dim=down_ops[1],
+ num_attention_heads=down_ops[2],
+ attention_ratio=down_ops[3],
+ stride=down_ops[5],
+ resolution_in=resolution_in,
+ resolution_out=self.resolution_out,
+ )
+ )
+ self.resolution_in = self.resolution_out
+ if down_ops[4] > 0:
+ hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4]
+ self.layers.append(
+ LevitResidualLayer(
+ LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), self.config.drop_path_rate
+ )
+ )
+
+ self.layers = nn.ModuleList(self.layers)
+
+ def get_resolution(self):
+ return self.resolution_in
+
+ def forward(self, hidden_state):
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+class LevitEncoder(nn.Module):
+ """
+ LeViT Encoder consisting of multiple `LevitStage` stages.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ resolution = self.config.image_size // self.config.patch_size
+ self.stages = []
+ self.config.down_ops.append([""])
+
+ for stage_idx in range(len(config.depths)):
+ stage = LevitStage(
+ config,
+ stage_idx,
+ config.hidden_sizes[stage_idx],
+ config.key_dim[stage_idx],
+ config.depths[stage_idx],
+ config.num_attention_heads[stage_idx],
+ config.attention_ratio[stage_idx],
+ config.mlp_ratio[stage_idx],
+ config.down_ops[stage_idx],
+ resolution,
+ )
+ resolution = stage.get_resolution()
+ self.stages.append(stage)
+
+ self.stages = nn.ModuleList(self.stages)
+
+ def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
+ all_hidden_states = () if output_hidden_states else None
+
+ for stage in self.stages:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_state,)
+ hidden_state = stage(hidden_state)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_state,)
+ if not return_dict:
+ return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
+
+
+class LevitClassificationLayer(nn.Module):
+ """
+ LeViT Classification Layer
+ """
+
+ def __init__(self, input_dim, output_dim):
+ super().__init__()
+ self.batch_norm = nn.BatchNorm1d(input_dim)
+ self.linear = nn.Linear(input_dim, output_dim)
+
+ def forward(self, hidden_state):
+ hidden_state = self.batch_norm(hidden_state)
+ logits = self.linear(hidden_state)
+ return logits
+
+
+class LevitPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LevitConfig
+ base_model_prefix = "levit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LevitModel):
+ module.gradient_checkpointing = value
+
+
+LEVIT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`LevitConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+LEVIT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Levit model outputting raw features without any specific head on top.",
+ LEVIT_START_DOCSTRING,
+)
+class LevitModel(LevitPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.patch_embeddings = LevitPatchEmbeddings(config)
+ self.encoder = LevitEncoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embeddings = self.patch_embeddings(pixel_values)
+ encoder_outputs = self.encoder(
+ embeddings,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes)
+ pooled_output = last_hidden_state.mean(dim=1)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ LEVIT_START_DOCSTRING,
+)
+class LevitForImageClassification(LevitPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.num_labels = config.num_labels
+ self.levit = LevitModel(config)
+
+ # Classifier head
+ self.classifier = (
+ LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
+ if config.num_labels > 0
+ else torch.nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=ImageClassifierOutputWithNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ sequence_output = outputs[0]
+ sequence_output = sequence_output.mean(1)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and
+ a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. .. warning::
+ This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+ supported.
+ """,
+ LEVIT_START_DOCSTRING,
+)
+class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.num_labels = config.num_labels
+ self.levit = LevitModel(config)
+
+ # Classifier head
+ self.classifier = (
+ LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
+ if config.num_labels > 0
+ else torch.nn.Identity()
+ )
+ self.classifier_distill = (
+ LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
+ if config.num_labels > 0
+ else torch.nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=LevitForImageClassificationWithTeacherOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ sequence_output = outputs[0]
+ sequence_output = sequence_output.mean(1)
+ cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output)
+ logits = (cls_logits + distill_logits) / 2
+
+ if not return_dict:
+ output = (logits, cls_logits, distill_logits) + outputs[2:]
+ return output
+
+ return LevitForImageClassificationWithTeacherOutput(
+ logits=logits,
+ cls_logits=cls_logits,
+ distillation_logits=distill_logits,
+ hidden_states=outputs.hidden_states,
+ )
diff --git a/src/transformers/models/longformer/__init__.py b/src/transformers/models/longformer/__init__.py
index 329b8f1cdf9207..1705703b5ac333 100644
--- a/src/transformers/models/longformer/__init__.py
+++ b/src/transformers/models/longformer/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -30,10 +36,20 @@
"tokenization_longformer": ["LongformerTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_longformer_fast"] = ["LongformerTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_longformer"] = [
"LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"LongformerForMaskedLM",
@@ -46,7 +62,12 @@
"LongformerSelfAttention",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_longformer"] = [
"TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLongformerForMaskedLM",
@@ -68,10 +89,20 @@
)
from .tokenization_longformer import LongformerTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_longformer_fast import LongformerTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
LongformerForMaskedLM,
@@ -84,7 +115,12 @@
LongformerSelfAttention,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,
diff --git a/src/transformers/models/longformer/configuration_longformer.py b/src/transformers/models/longformer/configuration_longformer.py
index 2c9fd17b35ec19..53ceeafb64bad2 100644
--- a/src/transformers/models/longformer/configuration_longformer.py
+++ b/src/transformers/models/longformer/configuration_longformer.py
@@ -24,9 +24,15 @@
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/config.json",
"allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/config.json",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/config.json",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/config.json",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/config.json",
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/config.json"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/config.json"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py
index 647bb8fb7319d4..30db98dea1abf8 100755
--- a/src/transformers/models/longformer/modeling_longformer.py
+++ b/src/transformers/models/longformer/modeling_longformer.py
@@ -388,9 +388,10 @@ def _get_question_end_index(input_ids, sep_token_id):
batch_size = input_ids.shape[0]
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
- assert (
- sep_token_indices.shape[0] == 3 * batch_size
- ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
+ assert sep_token_indices.shape[0] == 3 * batch_size, (
+ f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You"
+ " might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
+ )
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
@@ -600,7 +601,10 @@ def forward(
seq_len,
self.num_heads,
self.one_sided_attn_window_size * 2 + 1,
- ], f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ ], (
+ f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
+ )
# compute local attention probs from global attention keys and contact over window dim
if is_global_attn:
@@ -1040,7 +1044,11 @@ def _compute_global_attn_output_from_hidden(
batch_size * self.num_heads,
max_num_global_attn_indices,
seq_len,
- ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}."
+ ], (
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {global_attn_scores.size()}."
+ )
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
@@ -1083,7 +1091,11 @@ def _compute_global_attn_output_from_hidden(
batch_size * self.num_heads,
max_num_global_attn_indices,
self.head_dim,
- ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}."
+ ], (
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {global_attn_output.size()}."
+ )
global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
global_attn_output = global_attn_output.view(
@@ -1634,10 +1646,10 @@ def forward(
>>> attention_mask = torch.ones(
... input_ids.shape, dtype=torch.long, device=input_ids.device
- >>> ) # initialize to local attention
+ ... ) # initialize to local attention
>>> global_attention_mask = torch.zeros(
... input_ids.shape, dtype=torch.long, device=input_ids.device
- >>> ) # initialize to global attention to be deactivated for all tokens
+ ... ) # initialize to global attention to be deactivated for all tokens
>>> global_attention_mask[
... :,
... [
@@ -1645,7 +1657,7 @@ def forward(
... 4,
... 21,
... ],
- >>> ] = 1 # Set global attention to random tokens for the sake of this example
+ ... ] = 1 # Set global attention to random tokens for the sake of this example
>>> # Usually, set global attention based on the task. For example,
>>> # classification: the token
>>> # QA: question tokens
@@ -1692,7 +1704,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)[
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
:, 0, 0, :
]
@@ -1770,23 +1782,31 @@ def forward(
Returns:
- Examples:
+ Mask filling example:
```python
- >>> import torch
- >>> from transformers import LongformerForMaskedLM, LongformerTokenizer
+ >>> from transformers import LongformerTokenizer, LongformerForMaskedLM
- >>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
>>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
+ >>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
+ ```
- >>> SAMPLE_TEXT = " ".join(["Hello world! "] * 1000) # long input document
- >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
+ Let's try a very long input.
- >>> attention_mask = None # default is local attention everywhere, which is a good choice for MaskedLM
- >>> # check `LongformerModel.forward` for more details how to set *attention_mask*
- >>> outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
- >>> loss = outputs.loss
- >>> prediction_logits = outputs.logits
+ ```python
+ >>> TXT = (
+ ... "My friends are but they eat too many carbs."
+ ... + " That's why I decide not to eat with them." * 300
+ ... )
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
+ >>> logits = model(input_ids).logits
+
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ >>> probs = logits[0, masked_index].softmax(dim=0)
+ >>> values, predictions = probs.topk(5)
+
+ >>> tokenizer.decode(predictions).split()
+ ['healthy', 'skinny', 'thin', 'good', 'vegetarian']
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1848,9 +1868,11 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="jpelhaw/longformer-base-plagiarism-detection",
output_type=LongformerSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="'ORIGINAL'",
+ expected_loss=5.44,
)
def forward(
self,
@@ -2025,14 +2047,15 @@ def forward(
>>> answer_tokens = all_tokens[torch.argmax(start_logits) : torch.argmax(end_logits) + 1]
>>> answer = tokenizer.decode(
... tokenizer.convert_tokens_to_ids(answer_tokens)
- >>> ) # remove space prepending space token
+ ... ) # remove space prepending space token
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if global_attention_mask is None:
if input_ids is None:
logger.warning(
- "It is not possible to automatically generate the `global_attention_mask` because input_ids is None. Please make sure that it is correctly set."
+ "It is not possible to automatically generate the `global_attention_mask` because input_ids is"
+ " None. Please make sure that it is correctly set."
)
else:
# set global attention on question tokens automatically
@@ -2114,9 +2137,14 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="brad1141/Longformer-finetuned-norm",
output_type=LongformerTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output=(
+ "['Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence',"
+ " 'Evidence', 'Evidence', 'Evidence', 'Evidence']"
+ ),
+ expected_loss=0.63,
)
def forward(
self,
diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py
index 124fe2c06fecbc..0dfd9c66617f29 100644
--- a/src/transformers/models/longformer/modeling_tf_longformer.py
+++ b/src/transformers/models/longformer/modeling_tf_longformer.py
@@ -775,7 +775,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
- message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
+ message=(
+ f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
+ f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
+ ),
)
# compute global attn indices required through out forward fn
@@ -828,7 +831,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@@ -921,7 +927,10 @@ def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
- message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}",
+ message=(
+ f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
+ f" {shape_list(key)}"
+ ),
)
chunks_count = seq_len // window_overlap - 1
@@ -1206,7 +1215,10 @@ def _chunk(hidden_states, window_overlap):
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
- message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.",
+ message=(
+ "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
+ f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
+ ),
)
chunked_hidden_states = tf.reshape(
@@ -1384,7 +1396,11 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
- message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
+ message=(
+ "global_attn_scores have the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
+ f" {shape_list(global_attn_scores)}."
+ ),
)
global_attn_scores = tf.reshape(
@@ -1423,7 +1439,10 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
@@ -1442,7 +1461,11 @@ def _compute_global_attn_output_from_hidden(
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
- message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.",
+ message=(
+ "global_attn_output tensor has the wrong size. Size should be"
+ f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
+ f" {shape_list(global_attn_output)}."
+ ),
)
global_attn_output = tf.reshape(
@@ -2079,10 +2102,12 @@ def get_prefix_bias_name(self):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="allenai/longformer-base-4096",
output_type=TFLongformerMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
mask="",
+ expected_output="' Paris'",
+ expected_loss=0.44,
)
def call(
self,
@@ -2175,6 +2200,8 @@ def __init__(self, config, *inputs, **kwargs):
checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
output_type=TFLongformerQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="' puppet'",
+ expected_loss=0.96,
)
def call(
self,
@@ -2207,7 +2234,10 @@ def call(
if global_attention_mask is None and input_ids is not None:
if shape_list(tf.where(input_ids == self.config.sep_token_id))[0] != 3 * shape_list(input_ids)[0]:
logger.warning(
- f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error. The global attention is disabled for this forward pass."
+ f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for"
+ " questions answering. You might also consider to set `global_attention_mask` manually in the"
+ " forward function to avoid this. This is most likely an error. The global attention is disabled"
+ " for this forward pass."
)
global_attention_mask = tf.fill(shape_list(input_ids), value=0)
else:
@@ -2318,9 +2348,11 @@ def __init__(self, config, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="hf-internal-testing/tiny-random-longformer",
output_type=TFLongformerSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="'LABEL_1'",
+ expected_loss=0.69,
)
def call(
self,
@@ -2556,9 +2588,15 @@ def __init__(self, config, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="hf-internal-testing/tiny-random-longformer",
output_type=TFLongformerTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output=(
+ "['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1',"
+ " 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1',"
+ " 'LABEL_1', 'LABEL_1']"
+ ),
+ expected_loss=0.59,
)
def call(
self,
diff --git a/src/transformers/models/longformer/tokenization_longformer.py b/src/transformers/models/longformer/tokenization_longformer.py
index 19445622b821f7..b594580647a228 100644
--- a/src/transformers/models/longformer/tokenization_longformer.py
+++ b/src/transformers/models/longformer/tokenization_longformer.py
@@ -25,17 +25,33 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json",
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt",
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt"
+ ),
},
}
diff --git a/src/transformers/models/longformer/tokenization_longformer_fast.py b/src/transformers/models/longformer/tokenization_longformer_fast.py
index a7d06b1fc3db92..45a88839711736 100644
--- a/src/transformers/models/longformer/tokenization_longformer_fast.py
+++ b/src/transformers/models/longformer/tokenization_longformer_fast.py
@@ -26,24 +26,50 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json",
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt",
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt"
+ ),
},
"tokenizer_file": {
- "allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/tokenizer.json",
- "allenai/longformer-large-4096": "https://huggingface.co/allenai/longformer-large-4096/resolve/main/tokenizer.json",
- "allenai/longformer-large-4096-finetuned-triviaqa": "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/tokenizer.json",
- "allenai/longformer-base-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/tokenizer.json",
- "allenai/longformer-large-4096-extra.pos.embd.only": "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/tokenizer.json",
+ "allenai/longformer-base-4096": (
+ "https://huggingface.co/allenai/longformer-base-4096/resolve/main/tokenizer.json"
+ ),
+ "allenai/longformer-large-4096": (
+ "https://huggingface.co/allenai/longformer-large-4096/resolve/main/tokenizer.json"
+ ),
+ "allenai/longformer-large-4096-finetuned-triviaqa": (
+ "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/tokenizer.json"
+ ),
+ "allenai/longformer-base-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/tokenizer.json"
+ ),
+ "allenai/longformer-large-4096-extra.pos.embd.only": (
+ "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/longt5/__init__.py b/src/transformers/models/longt5/__init__.py
new file mode 100644
index 00000000000000..fd355f6d5a9382
--- /dev/null
+++ b/src/transformers/models/longt5/__init__.py
@@ -0,0 +1,88 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_longt5": ["LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongT5Config", "LongT5OnnxConfig"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_longt5"] = [
+ "LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "LongT5EncoderModel",
+ "LongT5ForConditionalGeneration",
+ "LongT5Model",
+ "LongT5PreTrainedModel",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_longt5"] = [
+ "FlaxLongT5ForConditionalGeneration",
+ "FlaxLongT5Model",
+ "FlaxLongT5PreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_longt5 import LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP, LongT5Config, LongT5OnnxConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_longt5 import (
+ LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST,
+ LongT5EncoderModel,
+ LongT5ForConditionalGeneration,
+ LongT5Model,
+ LongT5PreTrainedModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_longt5 import (
+ FlaxLongT5ForConditionalGeneration,
+ FlaxLongT5Model,
+ FlaxLongT5PreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/longt5/configuration_longt5.py b/src/transformers/models/longt5/configuration_longt5.py
new file mode 100644
index 00000000000000..705fdc4939584b
--- /dev/null
+++ b/src/transformers/models/longt5/configuration_longt5.py
@@ -0,0 +1,178 @@
+# coding=utf-8
+# Copyright 2022, The LongT5 Authors and HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" LongT5 model configuration"""
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "google/long-t5-local-base": "https://huggingface.co/google/long-t5-local-base/blob/main/config.json",
+ "google/long-t5-local-large": "https://huggingface.co/google/long-t5-local-large/blob/main/config.json",
+ "google/long-t5-tglobal-base": "https://huggingface.co/google/long-t5-tglobal-base/blob/main/config.json",
+ "google/long-t5-tglobal-large": "https://huggingface.co/google/long-t5-tglobal-large/blob/main/config.json",
+}
+
+
+class LongT5Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is
+ used to instantiate a LongT5 model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5
+ [google/long-t5-local-base](https://huggingface.co/google/long-t5-local-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Arguments:
+ vocab_size (`int`, *optional*, defaults to 32128):
+ Vocabulary size of the LongT5 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`LongT5Model`].
+ d_model (`int`, *optional*, defaults to 512):
+ Size of the encoder layers and the pooler layer.
+ d_kv (`int`, *optional*, defaults to 64):
+ Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
+ num_heads`.
+ d_ff (`int`, *optional*, defaults to 2048):
+ Size of the intermediate feed forward layer in each `LongT5Block`.
+ num_layers (`int`, *optional*, defaults to 6):
+ Number of hidden layers in the Transformer encoder.
+ num_decoder_layers (`int`, *optional*):
+ Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+ num_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ local_radius (`int`, *optional*, defaults to 127)
+ Number of tokens to the left/right for each token to locally self-attend in a local attention mechanism.
+ global_block_size (`int`, *optional*, defaults to 16)
+ Lenght of blocks an input sequence is divided into for a global token representation. Used only for
+ `encoder_attention_type = "transient-global"`.
+ relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+ The number of buckets to use for each attention layer.
+ relative_attention_max_distance (`int`, *optional*, defaults to 128):
+ The maximum distance of the longer sequences for the bucket separation.
+ dropout_rate (`float`, *optional*, defaults to 0.1):
+ The ratio for all dropout layers.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the layer normalization layers.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+ feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+ Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. LongT5v1.1 uses the
+ `"gated-gelu"` feed forward projection. Original LongT5 implementation uses `"gated-gelu"`.
+ encoder_attention_type (`string`, *optional*, defaults to `"local"`):
+ Type of encoder attention to be used. Should be one of `"local"` or `"transient-global"`, which are
+ supported by LongT5 implementation.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ """
+ model_type = "longt5"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+
+ def __init__(
+ self,
+ vocab_size=32128,
+ d_model=512,
+ d_kv=64,
+ d_ff=2048,
+ num_layers=6,
+ num_decoder_layers=None,
+ num_heads=8,
+ local_radius=127,
+ global_block_size=16,
+ relative_attention_num_buckets=32,
+ relative_attention_max_distance=128,
+ dropout_rate=0.1,
+ layer_norm_epsilon=1e-6,
+ initializer_factor=1.0,
+ feed_forward_proj="relu",
+ is_encoder_decoder=True,
+ encoder_attention_type="local",
+ use_cache=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ **kwargs
+ ):
+
+ self.vocab_size = vocab_size
+ self.d_model = d_model
+ self.d_kv = d_kv
+ self.d_ff = d_ff
+ self.num_layers = num_layers
+ # default = symmetry
+ self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
+ self.num_heads = num_heads
+ self.local_radius = local_radius
+ self.global_block_size = global_block_size
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.relative_attention_max_distance = relative_attention_max_distance
+ self.dropout_rate = dropout_rate
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_factor = initializer_factor
+ self.feed_forward_proj = feed_forward_proj
+ self.encoder_attention_type = encoder_attention_type
+ self.use_cache = use_cache
+
+ act_info = self.feed_forward_proj.split("-")
+ self.dense_act_fn = act_info[-1]
+ self.is_gated_act = act_info[0] == "gated"
+
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+ raise ValueError(
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+ "'gated-gelu' or 'relu'"
+ )
+
+ # for backwards compatibility
+ if feed_forward_proj == "gated-gelu":
+ self.dense_act_fn = "gelu_new"
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ **kwargs,
+ )
+
+
+class LongT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ common_inputs = {
+ "input_ids": {0: "batch", 1: "encoder_sequence"},
+ "attention_mask": {0: "batch", 1: "encoder_sequence"},
+ }
+ if self.use_past:
+ common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+ common_inputs["decoder_input_ids"] = {0: "batch"}
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+ else:
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+ if self.use_past:
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+ return common_inputs
+
+ @property
+ def default_onnx_opset(self) -> int:
+ return 13
diff --git a/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py b/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py
new file mode 100644
index 00000000000000..41cc3a2005dd94
--- /dev/null
+++ b/src/transformers/models/longt5/convert_longt5x_checkpoint_to_flax.py
@@ -0,0 +1,214 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert T5/LongT5X checkpoints from the original repository to JAX/FLAX model. This script is an extension of
+'src/transformers/models/t5/convert_t5x_checkpoint_to_flax.
+"""
+
+import argparse
+
+from t5x import checkpoints
+from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
+
+
+def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
+ config = AutoConfig.from_pretrained(config_name)
+ flax_model = FlaxAutoModelForSeq2SeqLM.from_config(config=config)
+ t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
+
+ split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
+
+ if config.model_type == "t5":
+ encoder_attn_name = "SelfAttention"
+ if config.model_type == "longt5" and config.encoder_attention_type == "local":
+ encoder_attn_name = "LocalSelfAttention"
+ elif config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
+ encoder_attn_name = "TransientGlobalSelfAttention"
+ else:
+ raise ValueError(
+ "Given config is expected to have `model_type='t5'`, or `model_type='longt5` with `encoder_attention_type`"
+ " attribute with a value from ['local', 'transient-global]."
+ )
+
+ # Encoder
+ for layer_index in range(config.num_layers):
+ layer_name = f"layers_{str(layer_index)}"
+
+ # Self-Attention
+ t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
+ t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
+ t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
+ t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
+
+ # Global input layer norm
+ if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
+ t5x_global_layer_norm = t5x_model["target"]["encoder"][layer_name]["attention"]["T5LayerNorm_0"]["scale"]
+
+ # Layer Normalization
+ t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
+
+ if split_mlp_wi:
+ t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
+ t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
+ else:
+ t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
+
+ t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
+
+ # Layer Normalization
+ t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
+
+ # Assigning
+ flax_model_encoder_layer_block = flax_model.params["encoder"]["block"][str(layer_index)]["layer"]
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["k"]["kernel"] = t5x_attention_key
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["o"]["kernel"] = t5x_attention_out
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["q"]["kernel"] = t5x_attention_query
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["v"]["kernel"] = t5x_attention_value
+
+ flax_model_encoder_layer_block["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
+
+ # Global input layer norm
+ if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
+ flax_model_encoder_layer_block["0"][encoder_attn_name]["global_input_layer_norm"][
+ "weight"
+ ] = t5x_global_layer_norm
+
+ if split_mlp_wi:
+ flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
+ flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
+ else:
+ flax_model_encoder_layer_block["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
+
+ flax_model_encoder_layer_block["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
+ flax_model_encoder_layer_block["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
+
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"] = flax_model_encoder_layer_block
+
+ # Only for layer 0:
+ t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["relative_attention_bias"][
+ "embedding"
+ ] = t5x_encoder_rel_embedding
+
+ # Side/global relative position_bias + layer norm
+ if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
+ t5x_encoder_global_rel_embedding = t5x_model["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["global_relative_attention_bias"][
+ "embedding"
+ ] = t5x_encoder_global_rel_embedding
+
+ # Assigning
+ t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
+ flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
+
+ # Decoder
+ for layer_index in range(config.num_layers):
+ layer_name = f"layers_{str(layer_index)}"
+
+ # Self-Attention
+ t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
+ t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
+ t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
+ t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
+
+ # Layer Normalization
+ t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][
+ "scale"
+ ]
+
+ # Encoder-Decoder-Attention
+ t5x_enc_dec_attention_module = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]
+ t5x_enc_dec_attention_key = t5x_enc_dec_attention_module["key"]["kernel"]
+ t5x_enc_dec_attention_out = t5x_enc_dec_attention_module["out"]["kernel"]
+ t5x_enc_dec_attention_query = t5x_enc_dec_attention_module["query"]["kernel"]
+ t5x_enc_dec_attention_value = t5x_enc_dec_attention_module["value"]["kernel"]
+
+ # Layer Normalization
+ t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
+
+ # MLP
+ if split_mlp_wi:
+ t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
+ t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
+ else:
+ t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
+
+ t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
+
+ # Layer Normalization
+ tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
+
+ # Assigning
+ flax_model_decoder_layer_block = flax_model.params["decoder"]["block"][str(layer_index)]["layer"]
+ flax_model_decoder_layer_block["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
+ flax_model_decoder_layer_block["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
+ flax_model_decoder_layer_block["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
+ flax_model_decoder_layer_block["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
+
+ flax_model_decoder_layer_block["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
+
+ flax_model_decoder_layer_block["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
+ flax_model_decoder_layer_block["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
+ flax_model_decoder_layer_block["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
+ flax_model_decoder_layer_block["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
+
+ flax_model_decoder_layer_block["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
+
+ if split_mlp_wi:
+ flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
+ flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
+ else:
+ flax_model_decoder_layer_block["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
+
+ flax_model_decoder_layer_block["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
+
+ flax_model_decoder_layer_block["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
+
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"] = flax_model_decoder_layer_block
+
+ # Decoder Normalization
+ tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
+ flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
+
+ # Only for layer 0:
+ t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
+ flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
+ "embedding"
+ ] = t5x_decoder_rel_embedding
+
+ # Token Embeddings
+ tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
+ flax_model.params["shared"]["embedding"] = tx5_token_embeddings
+
+ # LM Head (only in v1.1 and LongT5 checkpoints)
+ if "logits_dense" in t5x_model["target"]["decoder"]:
+ flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
+
+ flax_model.save_pretrained(flax_dump_folder_path)
+ print("T5X Model was sucessfully converted!")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the T5X checkpoint."
+ )
+ parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of LongT5/T5 model.")
+ parser.add_argument(
+ "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
+ )
+ args = parser.parse_args()
+ convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py
new file mode 100644
index 00000000000000..8ea0b38bb43094
--- /dev/null
+++ b/src/transformers/models/longt5/modeling_flax_longt5.py
@@ -0,0 +1,2401 @@
+# coding=utf-8
+# Copyright 2022 LongT5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Flax LongT5 model."""
+
+
+import copy
+from typing import Any, Callable, List, Optional, Tuple
+
+import numpy as np
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
+ FlaxSeq2SeqLMOutput,
+ FlaxSeq2SeqModelOutput,
+)
+from ...modeling_flax_utils import (
+ ACT2FN,
+ FlaxPreTrainedModel,
+ append_call_sample_docstring,
+ append_replace_return_docstrings,
+ overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_longt5 import LongT5Config
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/long-t5-local-base"
+_CONFIG_FOR_DOC = "LongT5Config"
+_TOKENIZER_FOR_DOC = "T5Tokenizer"
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = np.zeros_like(input_ids)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+ return shifted_input_ids
+
+
+def _pad_to_multiple(x: jnp.ndarray, block_len: int, axis: int, pad_value: int = 0) -> jnp.ndarray:
+ """Pad an array so that a sequence length will be a multiple of `block_len`"""
+ pad_len = -x.shape[axis] % block_len
+ pad = [(0, 0)] * x.ndim
+ pad[axis] = (0, pad_len)
+ x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)
+ return x
+
+
+def _split_into_blocks(x: jnp.ndarray, block_len: int, axis: int) -> jnp.ndarray:
+ """Split an input array into blocks of a given `block_len` along the given `axis`. If the dimension length
+ is not a multiple of `block_len`, it will be padded first with selected `pad_value`.
+ """
+ # pad tensor to multiple of block_len
+ if x.shape[axis] % block_len != 0:
+ x = _pad_to_multiple(x, block_len, axis, pad_value=0)
+ num_blocks = x.shape[axis] // block_len
+ output_shape = x.shape[:axis] + (num_blocks, block_len) + x.shape[(axis + 1) :]
+ return x.reshape(output_shape)
+
+
+def _concatenate_3_blocks(x: jnp.ndarray, block_axis: int, sequence_axis: int, pad_value: int = 0) -> jnp.ndarray:
+ """Concatenate three consecutive blocks for each input block for local attentiont.
+ For more information, see: https://arxiv.org/pdf/2112.07916.pdf.
+ """
+ num_blocks = x.shape[block_axis]
+
+ pad = [(0, 0)] * x.ndim
+ pad[block_axis] = (1, 1)
+ # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
+ x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)
+
+ blocks_list: List[np.array] = []
+ for i in range(3):
+ # We use indexing approach here:
+ # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
+ indices = [slice(0, None)] * x.ndim
+ indices[block_axis] = slice(i, i + num_blocks)
+ indices = tuple(indices)
+ blocks_list.append(x[indices])
+ return jnp.concatenate(blocks_list, axis=sequence_axis) # [batch_size, num_blocks, 3 * block_len, ...]
+
+
+def _make_3block_relative_position_ids(block_len: int) -> jnp.ndarray:
+ """Makes 3-blocked relative position ids for local attention."""
+ position_ids = jnp.arange(3 * block_len, dtype=jnp.int32)
+ center_position_ids = position_ids[block_len:-block_len]
+ relative_position_ids = position_ids[None, :] - center_position_ids[:, None] # [block_len, 3 * block_len]
+ return relative_position_ids
+
+
+def _mask_local_attention_mask(local_attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
+ """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius."""
+ relative_position_ids = _make_3block_relative_position_ids(block_len)
+ locality_mask = jnp.abs(relative_position_ids) < block_len
+ locality_mask = locality_mask[None, None, :, :]
+ return jnp.logical_and(local_attention_mask, locality_mask)
+
+
+def _get_local_attention_mask(attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
+ """Prepare attention mask to be applied for a local attention."""
+ # [batch_size, num_blocks, block_len]
+ _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, axis=1)
+ # [batch_size, num_block, 3 * block_len]
+ _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_axis=1, sequence_axis=2)
+
+ _blocked_attention_mask = _blocked_attention_mask[..., None]
+ _3blocked_attention_mask = _3blocked_attention_mask[..., None, :]
+ # [batch_size, num_block, block_len, 3 * block_len]
+ local_attention_mask = jnp.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
+ local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
+ # [batch_size, 1, num_block, block_len, 3 * block_len]
+ return local_attention_mask[:, None, ...]
+
+
+def _make_global_fixed_block_ids(attention_mask: np.ndarray, global_block_size: int) -> Tuple[jnp.ndarray, np.ndarray]:
+ """Obtain the "fixed block" global id corresponding to each input token.
+
+ This implementation is a simlified version of the original Flaxformr implementation adopted from:
+ https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.
+
+ In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for
+ the whole fixed block, are assigned to the preceding block.
+
+ Padding tokens from the original sequence are represented by -1.
+ """
+ batch_size, seq_len = attention_mask.shape[:2]
+
+ def handle_orphan_tokens(block_ids: np.ndarray) -> jnp.ndarray:
+ block_ends = (jnp.arange(seq_len) % global_block_size) == global_block_size - 1
+ true_block_ends = jnp.logical_and(block_ends, block_ids >= 0)
+ full_blocks = true_block_ends.sum(-1)[..., None]
+ block_ids = jnp.minimum(block_ids, full_blocks - 1)
+ return block_ids
+
+ fixed_block_mask = jnp.ones_like(attention_mask) / global_block_size
+ fixed_block_mask = jnp.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
+ mask = jnp.where(attention_mask != 0.0, 1.0, -1000.0)
+ global_block_ids = jnp.maximum(
+ jnp.floor(mask + fixed_block_mask - 1.0), jnp.array(-1.0, dtype=attention_mask.dtype)
+ )
+ # set padding tokens to -1
+ global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
+ # [batch_size, seq_len]
+ global_block_ids = handle_orphan_tokens(global_block_ids)
+ num_globals = seq_len // global_block_size
+
+ # [batch_size, seq_len // global_block_size]
+ if num_globals > 0:
+ _sequence_block_ids_max = jnp.repeat(global_block_ids.max(axis=-1)[:, None], repeats=num_globals, axis=1)
+ else:
+ _sequence_block_ids_max = jnp.zeros((batch_size, 0), dtype=global_block_ids.dtype)
+ global_segment_ids = jnp.cumsum(jnp.ones((batch_size, num_globals)), axis=-1) - 1
+ global_segment_ids = jnp.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
+ return global_block_ids, global_segment_ids
+
+
+def _make_side_relative_position_ids(attention_mask: np.ndarray, global_block_size: int) -> np.ndarray:
+ """Create the relative position tensor for local -> global attention."""
+ block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
+ global_seq_len = global_segment_ids.shape[-1]
+ global_positions = jnp.arange(global_seq_len)
+ side_relative_position = global_positions - block_ids[..., None]
+ return side_relative_position
+
+
+def _create_global_aggregates(hidden_states: np.ndarray, block_ids: np.ndarray, global_seq_len: int) -> np.ndarray:
+ """Compute individual block aggregates by summing over individual blocks."""
+ # (batch..., seq_len, global_seq_len))
+ one_hot_block_ids = jax.nn.one_hot(block_ids, global_seq_len)
+ return jnp.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids)
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerNorm with T5->LongT5
+class FlaxLongT5LayerNorm(nn.Module):
+ hidden_size: int
+ dtype: jnp.dtype = jnp.float32
+ eps: float = 1e-6
+ weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
+
+ def setup(self):
+ self.weight = self.param("weight", self.weight_init, (self.hidden_size,))
+
+ def __call__(self, hidden_states):
+ """
+ Construct a layernorm module in the LongT5 style; No bias and no subtraction of mean.
+ """
+ # layer norm should always be calculated in float32
+ variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
+ hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
+
+ return self.weight * hidden_states
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseActDense with T5->LongT5
+class FlaxLongT5DenseActDense(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
+ wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
+
+ self.wi = nn.Dense(
+ self.config.d_ff,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wi_init_std),
+ dtype=self.dtype,
+ )
+ self.wo = nn.Dense(
+ self.config.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wo_init_std),
+ dtype=self.dtype,
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+ self.act = ACT2FN[self.config.dense_act_fn]
+
+ def __call__(self, hidden_states, deterministic=True):
+ hidden_states = self.wi(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseGatedActDense with T5->LongT5
+class FlaxLongT5DenseGatedActDense(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
+ wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)
+
+ self.wi_0 = nn.Dense(
+ self.config.d_ff,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wi_init_std),
+ dtype=self.dtype,
+ )
+ self.wi_1 = nn.Dense(
+ self.config.d_ff,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wi_init_std),
+ dtype=self.dtype,
+ )
+ self.wo = nn.Dense(
+ self.config.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(wo_init_std),
+ dtype=self.dtype,
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+ self.act = ACT2FN[self.config.dense_act_fn]
+
+ def __call__(self, hidden_states, deterministic):
+ hidden_gelu = self.act(self.wi_0(hidden_states))
+ hidden_linear = self.wi_1(hidden_states)
+ hidden_states = hidden_gelu * hidden_linear
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerFF with T5->LongT5
+class FlaxLongT5LayerFF(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ if self.config.is_gated_act:
+ self.DenseReluDense = FlaxLongT5DenseGatedActDense(self.config, dtype=self.dtype)
+ else:
+ self.DenseReluDense = FlaxLongT5DenseActDense(self.config, dtype=self.dtype)
+
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(self, hidden_states, deterministic=True):
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic)
+ hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention with T5->LongT5
+class FlaxLongT5Attention(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ causal: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
+ self.relative_attention_max_distance = self.config.relative_attention_max_distance
+ self.d_model = self.config.d_model
+ self.key_value_proj_dim = self.config.d_kv
+ self.n_heads = self.config.num_heads
+ self.dropout = self.config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+ kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+ o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+
+ self.q = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(q_init_std),
+ dtype=self.dtype,
+ )
+ self.k = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.v = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.o = nn.Dense(
+ self.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(o_init_std),
+ dtype=self.dtype,
+ )
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embed(
+ self.relative_attention_num_buckets,
+ self.n_heads,
+ embedding_init=jax.nn.initializers.normal(kv_init_std),
+ )
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0) * num_buckets
+ relative_position = jnp.abs(relative_position)
+ else:
+ relative_position = -jnp.clip(relative_position, a_max=0)
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
+ )
+ relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+
+ relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
+
+ return relative_buckets.astype("i4")
+
+ def compute_bias(self, query_length, key_length):
+ """Compute binned relative position bias"""
+ context_position = jnp.arange(query_length, dtype="i4")[:, None]
+ memory_position = jnp.arange(key_length, dtype="i4")[None, :]
+
+ relative_position = memory_position - context_position
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=(not self.causal),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.transpose((2, 0, 1))[None, :, :, :]
+ return values
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slighly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = jax.lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = jax.lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions
+ # that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def _create_position_bias(
+ self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
+ ):
+ cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache)
+ key_length = key_states.shape[1]
+ query_length = key_length if cache_is_filled else query_states.shape[1]
+
+ if self.has_relative_attention_bias:
+ position_bias = self.compute_bias(query_length, key_length)
+ elif attention_mask is not None:
+ position_bias = jnp.zeros_like(attention_mask)
+ else:
+ position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype)
+
+ # if key and values are already calculated, only the last query position bias should be taken
+ if cache_is_filled:
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ position_bias = jax.lax.dynamic_slice(
+ position_bias,
+ (0, 0, causal_attention_mask_shift, 0),
+ (1, self.n_heads, seq_length, max_decoder_length),
+ )
+ return position_bias
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ key_value_states=None,
+ position_bias=None,
+ use_cache=False,
+ output_attentions=False,
+ deterministic=True,
+ init_cache=False,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # q, k, v projections
+ query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head)
+ key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
+ value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)
+
+ # reshape to (batch_size, seq_length, n_heads, head_dim)
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # counter-act scaling in dot_product_attention_weights function
+ query_states *= jnp.sqrt(query_states.shape[-1])
+
+ # for fast decoding causal attention mask should be shifted
+ causal_attention_mask_shift = (
+ self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0
+ )
+ # create causal attention_mask; attention_mask has to be defined when model is causal
+ if self.causal:
+ causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")
+
+ # fast decoding for generate requires special attention_mask
+ if self.has_variable("cache", "cached_key"):
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_attention_mask = jax.lax.dynamic_slice(
+ causal_attention_mask,
+ (0, 0, causal_attention_mask_shift, 0),
+ (1, 1, seq_length, max_decoder_length),
+ )
+
+ # broadcast causal attention mask & attention mask to fit for merge
+ causal_attention_mask = jnp.broadcast_to(
+ causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
+ )
+ attention_mask = jnp.broadcast_to(
+ jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape
+ )
+ attention_mask = combine_masks(attention_mask, causal_attention_mask)
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
+
+ # replace masked positions with -10_000
+ if attention_mask is not None:
+ attention_mask = jax.lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+ )
+
+ if position_bias is None:
+ # compute position bias (only for first layer)
+ position_bias = self._create_position_bias(
+ key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
+ )
+
+ if attention_mask is not None:
+ position_bias = position_bias + attention_mask
+
+ # create dropout rng
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # Softmax(QK^T)
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=position_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ )
+
+ # multiply with value states
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+
+ # bring back to (batch_size, seq_length, d_model)
+ attn_output = self._merge_heads(attn_output)
+
+ # apply output matrix
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+
+ return outputs
+
+
+class FlaxLongT5LocalAttention(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
+ self.relative_attention_max_distance = self.config.relative_attention_max_distance
+ self.d_model = self.config.d_model
+ self.key_value_proj_dim = self.config.d_kv
+ self.n_heads = self.config.num_heads
+ self.local_radius = self.config.local_radius
+ self.block_len = self.local_radius + 1
+ self.dropout = self.config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+ kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+ o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+
+ self.q = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(q_init_std),
+ dtype=self.dtype,
+ )
+ self.k = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.v = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.o = nn.Dense(
+ self.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(o_init_std),
+ dtype=self.dtype,
+ )
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embed(
+ self.relative_attention_num_buckets,
+ self.n_heads,
+ embedding_init=jax.nn.initializers.normal(kv_init_std),
+ )
+
+ @staticmethod
+ # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0) * num_buckets
+ relative_position = jnp.abs(relative_position)
+ else:
+ relative_position = -jnp.clip(relative_position, a_max=0)
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
+ )
+ relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+
+ relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
+
+ return relative_buckets.astype("i4")
+
+ def compute_bias(self, block_length: int):
+ """Compute binned relative position bias"""
+ memory_position = jnp.arange(3 * block_length, dtype="i4")
+ context_position = memory_position[block_length:-block_length]
+
+ relative_position = memory_position[None, :] - context_position[:, None]
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=True,
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.transpose((2, 0, 1))[None, None, :, :, :]
+ return values
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)
+
+ def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
+ # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
+ if self.has_relative_attention_bias:
+ position_bias = self.compute_bias(block_len)
+ elif attention_mask is not None:
+ position_bias = jnp.zeros_like(attention_mask)
+ else:
+ position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)
+
+ return position_bias
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ key_value_states=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # q, k, v projections
+ query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head)
+ key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
+ value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)
+
+ # reshape to (batch_size, seq_length, n_heads, head_dim)
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim)
+ query_states = _split_into_blocks(query_states, self.block_len, axis=1)
+ key_states = _split_into_blocks(key_states, self.block_len, axis=1)
+ value_states = _split_into_blocks(value_states, self.block_len, axis=1)
+
+ # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
+ key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2)
+ value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2)
+
+ # counter-act scaling in dot_product_attention_weights function
+ query_states *= jnp.sqrt(query_states.shape[-1])
+
+ if attention_mask is not None:
+ attention_mask = _get_local_attention_mask(attention_mask, self.block_len)
+
+ # replace masked positions with -10_000
+ attention_mask = jax.lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+ )
+
+ if position_bias is None:
+ # compute position bias (only for first layer)
+ position_bias = self._create_position_bias(self.block_len, attention_mask)
+
+ if attention_mask is not None:
+ position_bias = position_bias + attention_mask.swapaxes(1, 2)
+
+ # create dropout rng
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # Softmax(QK^T)
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=position_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ )
+
+ # multiply with value states
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+
+ # bring back to (batch_size, seq_length, d_model)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = attn_output[:, :seq_length, :]
+
+ # apply output matrix
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+
+ return outputs
+
+
+class FlaxLongT5TransientGlobalAttention(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
+ self.relative_attention_max_distance = self.config.relative_attention_max_distance
+ self.d_model = self.config.d_model
+ self.key_value_proj_dim = self.config.d_kv
+ self.n_heads = self.config.num_heads
+ self.local_radius = self.config.local_radius
+ self.block_len = self.local_radius + 1
+ self.global_block_size = self.config.global_block_size
+ self.dropout = self.config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+ kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+ o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
+
+ self.q = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(q_init_std),
+ dtype=self.dtype,
+ )
+ self.k = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.v = nn.Dense(
+ self.inner_dim,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(kv_init_std),
+ dtype=self.dtype,
+ )
+ self.o = nn.Dense(
+ self.d_model,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(o_init_std),
+ dtype=self.dtype,
+ )
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embed(
+ self.relative_attention_num_buckets,
+ self.n_heads,
+ embedding_init=jax.nn.initializers.normal(kv_init_std),
+ )
+
+ # Relativen attention bias & Layer norm for global attention
+ if self.has_relative_attention_bias:
+ self.global_relative_attention_bias = nn.Embed(
+ self.relative_attention_num_buckets,
+ self.n_heads,
+ embedding_init=jax.nn.initializers.normal(kv_init_std),
+ )
+ self.global_input_layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+
+ @staticmethod
+ # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0) * num_buckets
+ relative_position = jnp.abs(relative_position)
+ else:
+ relative_position = -jnp.clip(relative_position, a_max=0)
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
+ )
+ relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+
+ relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
+
+ return relative_buckets.astype("i4")
+
+ def compute_bias(self, block_length: int):
+ """Compute binned relative position bias"""
+ memory_position = jnp.arange(3 * block_length, dtype="i4")
+ context_position = memory_position[block_length:-block_length]
+
+ relative_position = memory_position[None, :] - context_position[:, None]
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=True,
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.transpose((2, 0, 1))[None, None, :, :, :]
+ return values
+
+ def compute_side_bias(self, attention_mask: np.ndarray, global_segment_ids: np.ndarray) -> np.ndarray:
+ # (batch_size, 1, 1, seq_len, global_seq_len)
+ side_attention_mask = jnp.equal(attention_mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
+ attention_side_bias = jax.lax.select(
+ side_attention_mask > 0,
+ jnp.full(side_attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(side_attention_mask.shape, -1e10).astype(self.dtype),
+ )
+ # (batch_size, seq_len, global_seq_len)
+ side_relative_position = _make_side_relative_position_ids(attention_mask, self.global_block_size)
+ side_relative_position_bucket = self._relative_position_bucket(
+ side_relative_position,
+ bidirectional=True,
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ # (batch_size, seq_len, global_seq_len, num_heads)
+ side_bias = self.global_relative_attention_bias(side_relative_position_bucket)
+
+ # (batch_size, 1, num_heads, seq_len, global_seq_len)
+ side_bias = jnp.transpose(side_bias, (0, 3, 1, 2))
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ attention_side_bias = attention_side_bias + side_bias
+ return attention_side_bias
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)
+
+ def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
+ # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
+ if self.has_relative_attention_bias:
+ position_bias = self.compute_bias(block_len)
+ elif attention_mask is not None:
+ position_bias = jnp.zeros_like(attention_mask)
+ else:
+ position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)
+
+ return position_bias
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ key_value_states=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ # Prepare components for transient-global attention
+ # Obtain block_ids and global_segment_ids
+ # global_seq_len := seq_len // self.global_block_size
+ # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
+ block_ids, global_segment_ids = _make_global_fixed_block_ids(
+ attention_mask if attention_mask is not None else jnp.ones((batch_size, seq_length)),
+ self.global_block_size,
+ )
+ # Create global inputs
+ _global_seq_len = global_segment_ids.shape[-1]
+ global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
+ global_inputs = self.global_input_layer_norm(global_inputs)
+
+ # q, k, v projections
+ query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head)
+ key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
+ value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)
+
+ # reshape to (batch_size, seq_length, n_heads, head_dim)
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # Get global/side key/value_states
+ side_key_states = self.k(global_inputs)
+ side_value_states = self.v(global_inputs)
+
+ # reshape to (batch_size, global_seq_len, n_heads, head_dim)
+ side_key_states = self._split_heads(side_key_states)
+ side_value_states = self._split_heads(side_value_states)
+
+ # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim)
+ query_states = _split_into_blocks(query_states, self.block_len, axis=1)
+ key_states = _split_into_blocks(key_states, self.block_len, axis=1)
+ value_states = _split_into_blocks(value_states, self.block_len, axis=1)
+
+ # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
+ key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2)
+ value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2)
+
+ # Tile side inputs across local key/value blocks
+ # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
+ reps = [1] * (side_key_states.ndim + 1)
+ reps[1] = key_states.shape[1]
+ side_key_states = jnp.tile(side_key_states[:, None, ...], reps)
+ side_value_states = jnp.tile(side_value_states[:, None, ...], reps)
+
+ # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
+ # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
+ key_states = jnp.concatenate((key_states, side_key_states), axis=2)
+ value_states = jnp.concatenate((value_states, side_value_states), axis=2)
+
+ # counter-act scaling in dot_product_attention_weights function
+ query_states *= jnp.sqrt(query_states.shape[-1])
+
+ if attention_mask is not None:
+ local_attention_mask = _get_local_attention_mask(attention_mask, self.block_len)
+ local_attention_mask = jax.lax.select(
+ local_attention_mask > 0,
+ jnp.full(local_attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(local_attention_mask.shape, -1e10).astype(self.dtype),
+ )
+ else:
+ local_attention_mask = None
+
+ if position_bias is None:
+ # compute position bias (only for first layer)
+ position_bias = self._create_position_bias(self.block_len, attention_mask)
+ if local_attention_mask is not None:
+ position_bias = position_bias + local_attention_mask.swapaxes(1, 2)
+
+ # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len)
+ if attention_mask is None:
+ attention_mask = jnp.ones((batch_size, seq_length))
+ side_position_bias = self.compute_side_bias(attention_mask, global_segment_ids)
+ side_position_bias = _split_into_blocks(side_position_bias, self.block_len, axis=-2)
+ side_position_bias = jnp.swapaxes(side_position_bias, 1, 2)
+ position_bias = jnp.concatenate((position_bias, side_position_bias), axis=-1)
+
+ # create dropout rng
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ # Softmax(QK^T)
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=position_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ )
+
+ # multiply with value states
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+
+ # bring back to (batch_size, seq_length, d_model)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = attn_output[:, :seq_length, :]
+
+ # apply output matrix
+ attn_output = self.o(attn_output)
+
+ outputs = (attn_output, position_bias)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+
+ return outputs
+
+
+class FlaxLongT5LayerLocalSelfAttention(nn.Module):
+ """Local self attention used in encoder"""
+
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.LocalSelfAttention = FlaxLongT5LocalAttention(
+ self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
+ )
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ **kwargs: Any, # to accept init_cache kwargs
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.LocalSelfAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class FlaxLongT5LayerTransientGlobalSelfAttention(nn.Module):
+ """Transient-Global self attention used in encoder"""
+
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.TransientGlobalSelfAttention = FlaxLongT5TransientGlobalAttention(
+ self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
+ )
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ **kwargs: Any, # to accept init_cache kwargs
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.TransientGlobalSelfAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerSelfAttention with T5->LongT5
+class FlaxLongT5LayerSelfAttention(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.SelfAttention = FlaxLongT5Attention(
+ self.config,
+ has_relative_attention_bias=self.has_relative_attention_bias,
+ causal=self.config.causal,
+ dtype=self.dtype,
+ )
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ init_cache=False,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.SelfAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCrossAttention with T5->LongT5
+class FlaxLongT5LayerCrossAttention(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.EncDecAttention = FlaxLongT5Attention(
+ self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
+ )
+ self.layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ hidden_states,
+ key_value_states,
+ attention_mask=None,
+ position_bias=None,
+ output_attentions=False,
+ deterministic=True,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.EncDecAttention(
+ normed_hidden_states,
+ attention_mask=attention_mask,
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class FlaxLongT5Block(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool = False
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.causal = self.config.causal
+ if self.causal:
+ attention_layer = FlaxLongT5LayerSelfAttention
+ elif self.config.encoder_attention_type == "local":
+ attention_layer = FlaxLongT5LayerLocalSelfAttention
+ elif self.config.encoder_attention_type == "transient-global":
+ attention_layer = FlaxLongT5LayerTransientGlobalSelfAttention
+ else:
+ raise ValueError(
+ "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
+ f"but got {self.config.encoder_attention_type}."
+ )
+ self.layer = (
+ attention_layer(
+ self.config,
+ has_relative_attention_bias=self.has_relative_attention_bias,
+ name=str(0),
+ dtype=self.dtype,
+ ),
+ )
+ feed_forward_index = 1
+ if self.causal:
+ self.layer += (FlaxLongT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
+ feed_forward_index += 1
+
+ self.layer += (FlaxLongT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
+
+ # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Block.__call__ with T5->LongT5
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ encoder_decoder_position_bias=None,
+ output_attentions=False,
+ return_dict=True,
+ deterministic=True,
+ init_cache=False,
+ ):
+ self_attention_outputs = self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+ hidden_states = self_attention_outputs[0]
+ attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
+
+ do_cross_attention = self.causal and encoder_hidden_states is not None
+ if do_cross_attention:
+ cross_attention_outputs = self.layer[1](
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_bias=encoder_decoder_position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+ hidden_states = cross_attention_outputs[0]
+
+ # Keep cross-attention outputs and relative position weights
+ attention_outputs = attention_outputs + cross_attention_outputs[1:]
+
+ # Apply Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states, deterministic=deterministic)
+
+ outputs = (hidden_states,)
+
+ outputs = outputs + attention_outputs
+
+ # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCollection with T5->LongT5
+class FlaxLongT5LayerCollection(nn.Module):
+ config: LongT5Config
+ has_relative_attention_bias: bool
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layer = FlaxLongT5Block(
+ self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
+ )
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ encoder_decoder_position_bias=None,
+ output_attentions=False,
+ return_dict=True,
+ deterministic=True,
+ init_cache=False,
+ ):
+ return self.layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5BlockCollection with T5->LongT5
+class FlaxLongT5BlockCollection(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.causal = self.config.causal
+ self.blocks = [
+ FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
+ for i in range(self.config.num_layers)
+ ]
+
+ def __call__(
+ self,
+ hidden_states=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ ):
+ # Prepare head mask if needed
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.causal) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ for i, layer_module in enumerate(self.blocks):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[1]
+
+ if self.causal and encoder_hidden_states is not None:
+ encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[2],)
+ if self.causal:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Stack with T5->LongT5
+class FlaxLongT5Stack(nn.Module):
+ config: LongT5Config
+ embed_tokens: nn.Embed
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.causal = self.config.causal
+
+ self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
+ self.final_layer_norm = FlaxLongT5LayerNorm(
+ self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
+ )
+ self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ def __call__(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ ):
+ hidden_states = self.embed_tokens(input_ids)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+
+ outputs = self.block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+ hidden_states = outputs[0]
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+
+ # Add last layer
+ all_hidden_states = None
+
+ if output_hidden_states:
+ all_hidden_states = outputs.hidden_states
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ if output_hidden_states:
+ return (
+ hidden_states,
+ all_hidden_states,
+ ) + outputs[2:]
+ return (hidden_states,) + outputs[1:]
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+LONGT5_ENCODE_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for detail.
+
+ To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
+ Training](./longt5#training).
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+LONGT5_DECODE_INPUTS_DOCSTRING = r"""
+ Args:
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ For training, `decoder_input_ids` should be provided.
+ encoder_outputs (`tuple(tuple(jnp.ndarray)`):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+LONGT5_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for detail.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
+ Training](./longt5#training).
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
+ Training](./longt5#training).
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*):
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class FlaxLongT5PreTrainedModel(FlaxPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LongT5Config
+ base_model_prefix = "transformer"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: LongT5Config,
+ input_shape: Tuple[int] = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+
+ attention_mask = jnp.ones_like(input_ids)
+ decoder_input_ids = jnp.ones_like(input_ids)
+ decoder_attention_mask = jnp.ones_like(input_ids)
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ )["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ decoder_input_ids: jnp.ndarray = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if decoder_input_ids is None:
+ raise ValueError(
+ "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
+ " here."
+ )
+
+ # prepare encoder inputs
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ # prepare decoder inputs
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ )
+
+ def init_cache(self, batch_size, max_length, encoder_outputs):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
+ cross-attention of the decoder.
+ """
+ # init input variables to retrieve cache
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs,
+ )
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0),
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ init_cache=True,
+ method=_decoder_forward, # we only need to call the decoder to init the cache
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings(LONGT5_ENCODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=LongT5Config)
+ def encode(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ def _encoder_forward(module, input_ids, attention_mask, **kwargs):
+ encode_module = module._get_encoder_module()
+ return encode_module(input_ids, attention_mask, **kwargs)
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ method=_encoder_forward,
+ )
+
+ @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=LongT5Config)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration
+ >>> import jax.numpy as jnp
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> logits = outputs.logits
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
+ # it can be changed by FlaxLongT5Attention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs,
+ )
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past = outputs
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past = outputs
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+LONGT5_START_DOCSTRING = r"""
+ The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long
+ Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo
+ Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising
+ generative setting. LongT5 model is an extension of T5 model, and it enables using one of the two different
+ efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention.
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`LongT5Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+
+@add_start_docstrings(
+ "The bare LONGT5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
+ LONGT5_START_DOCSTRING,
+)
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->LongT5
+class FlaxLongT5Module(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def _get_encoder_module(self):
+ return self.encoder
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def setup(self):
+ self.shared = nn.Embed(
+ self.config.vocab_size,
+ self.config.d_model,
+ embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+ )
+
+ encoder_config = copy.deepcopy(self.config)
+ encoder_config.causal = False
+ self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+
+ decoder_config = copy.deepcopy(self.config)
+ decoder_config.causal = True
+ decoder_config.num_layers = self.config.num_decoder_layers
+ self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=None,
+ decoder_attention_mask=None,
+ encoder_outputs=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ deterministic: bool = True,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Encode if needed (training, first prediction pass)
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return FlaxSeq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->LongT5
+class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
+ module_class = FlaxLongT5Module
+
+
+append_call_sample_docstring(
+ FlaxLongT5Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
+)
+
+FLAX_LONGT5_MODEL_DOCSTRING = """
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5Model
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5Model.from_pretrained("google/long-t5-local-base")
+
+ >>> input_ids = tokenizer(
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="np"
+ ... ).input_ids
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
+
+ >>> # forward pass
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```
+"""
+
+
+overwrite_call_docstring(FlaxLongT5Model, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_MODEL_DOCSTRING)
+append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+
+
+@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
+# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->LongT5
+class FlaxLongT5ForConditionalGenerationModule(nn.Module):
+ config: LongT5Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def _get_encoder_module(self):
+ return self.encoder
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def setup(self):
+ self.model_dim = self.config.d_model
+
+ self.shared = nn.Embed(
+ self.config.vocab_size,
+ self.config.d_model,
+ embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+ )
+
+ encoder_config = copy.deepcopy(self.config)
+ encoder_config.causal = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
+
+ decoder_config = copy.deepcopy(self.config)
+ decoder_config.causal = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = self.config.num_decoder_layers
+ self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
+
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
+ dtype=self.dtype,
+ )
+
+ def __call__(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=None,
+ decoder_attention_mask=None,
+ encoder_outputs=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ deterministic: bool = True,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Encode
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.model_dim**-0.5)
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.shared.variables["params"]["embedding"]
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
+ else:
+ lm_logits = self.lm_head(sequence_output)
+
+ if not return_dict:
+ return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+
+ return FlaxSeq2SeqLMOutput(
+ logits=lm_logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
+ module_class = FlaxLongT5ForConditionalGenerationModule
+
+ @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=LongT5Config)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration
+ >>> import jax.numpy as jnp
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
+
+ >>> text = "summarize: My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> logits = outputs.logits
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
+ # it can be changed by FlaxLongT5Attention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
+ decoder_module = module._get_decoder_module()
+ decoder_outputs = decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.config.d_model**-0.5)
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = module.shared.variables["params"]["embedding"]
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
+ else:
+ lm_logits = module.lm_head(sequence_output)
+
+ return lm_logits, decoder_outputs
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ if past_key_values is None:
+ lm_logits, decoder_outputs = outputs
+ else:
+ (lm_logits, decoder_outputs), past = outputs
+
+ if return_dict:
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
+ logits=lm_logits,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ )
+ else:
+ outputs = (lm_logits,) + decoder_outputs[1:]
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ max_length,
+ attention_mask: Optional[jnp.DeviceArray] = None,
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+ encoder_outputs=None,
+ **kwargs
+ ):
+ # initializing the cache
+ batch_size, seq_length = decoder_input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyways.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if decoder_attention_mask is not None:
+ extended_attention_mask = jax.lax.dynamic_update_slice(
+ extended_attention_mask, decoder_attention_mask, (0, 0)
+ )
+
+ return {
+ "past_key_values": past_key_values,
+ "encoder_outputs": encoder_outputs,
+ "encoder_attention_mask": attention_mask,
+ "decoder_attention_mask": extended_attention_mask,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ return model_kwargs
+
+
+FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING = """
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, FlaxLongT5ForConditionalGeneration
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
+ >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
+
+ >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np")
+
+ >>> # Generate Summary
+ >>> summary_ids = model.generate(inputs["input_ids"]).sequences
+ >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
+ ```
+"""
+
+
+overwrite_call_docstring(
+ FlaxLongT5ForConditionalGeneration, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+ FlaxLongT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py
new file mode 100644
index 00000000000000..cd6c91a7e8c1e7
--- /dev/null
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -0,0 +1,2192 @@
+# coding=utf-8
+# Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LongT5 model."""
+
+
+import copy
+import math
+import warnings
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ DUMMY_INPUTS,
+ DUMMY_MASK,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_torch_fx_proxy,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_longt5 import LongT5Config
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LongT5Config"
+_TOKENIZER_FOR_DOC = "T5Tokenizer"
+_CHECKPOINT_FOR_DOC = "google/long-t5-local-base"
+
+# TODO: Update before the merge
+LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "google/long-t5-local-base",
+ "google/long-t5-local-large",
+ "google/long-t5-tglobal-base",
+ "google/long-t5-tglobal-large",
+]
+
+
+def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor:
+ """Pad a tensor so that a sequence length will be a multiple of `block_len`"""
+ pad_len = -x.shape[dim] % block_len
+ # Handle cases when an empty input sequence is given
+ if not all(x.shape):
+ new_shape = list(x.shape)
+ new_shape[dim] += pad_len
+ return torch.zeros(new_shape, dtype=x.dtype)
+
+ pad = [(0, 0)] * x.ndim
+ pad[dim] = (0, pad_len)
+ pad = sum(pad[::-1], ())
+ x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
+ return x
+
+
+def _split_into_blocks(x: torch.Tensor, block_len: int, dim: int) -> torch.Tensor:
+ """Split an input tensor into blocks of a given `block_len` along the given `dim`. If the dimension length
+ is not a multiple of `block_len`, it will be padded first with selected `pad_value`.
+ """
+ # pad tensor to multiple of block_len
+ if x.shape[dim] % block_len != 0:
+ x = _pad_to_multiple(x, block_len, dim, pad_value=0)
+ num_blocks = x.shape[dim] // block_len
+ output_shape = x.shape[:dim] + (num_blocks, block_len) + x.shape[(dim + 1) :]
+ # If 0 is in output_shape, we cannot apply reshape because of incompatibility with ONNX conversion
+ if 0 in output_shape:
+ return torch.empty(output_shape, dtype=x.dtype, device=x.device)
+ return x.reshape(output_shape)
+
+
+def _concatenate_3_blocks(x: torch.Tensor, block_dim: int, sequence_dim: int, pad_value: int = 0) -> torch.Tensor:
+ """Concatenate three consecutive blocks for each input block for local attentiont.
+
+ For more information, see: https://arxiv.org/pdf/2112.07916.pdf.
+ """
+ num_blocks = x.shape[block_dim]
+
+ pad = [(0, 0)] * x.ndim
+ pad[block_dim] = (1, 1)
+ pad = sum(pad[::-1], ())
+ # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
+ x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
+
+ blocks_list: List[torch.Tensor] = []
+ for i in range(3):
+ # We use indexing approach here:
+ # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
+ indices = [slice(0, None)] * x.ndim
+ indices[block_dim] = slice(i, i + num_blocks)
+ indices = tuple(indices)
+ blocks_list.append(x[indices])
+ # [batch_size, num_blocks, 3 * block_len, ...]
+ return torch.cat(blocks_list, dim=sequence_dim)
+
+
+def _make_3block_relative_position_ids(block_len: int) -> torch.Tensor:
+ """Makes 3-blocked relative position ids for local attention."""
+ position_ids = torch.arange(3 * block_len, dtype=torch.int32)
+ center_position_ids = position_ids[block_len:-block_len]
+ # [block_len, 3 * block_len]
+ relative_position_ids = position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
+ return relative_position_ids
+
+
+def _mask_local_attention_mask(local_attention_mask: torch.Tensor, block_len: int) -> torch.Tensor:
+ """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius."""
+ relative_position_ids = _make_3block_relative_position_ids(block_len)
+ locality_mask = torch.abs(relative_position_ids) < block_len
+ locality_mask = locality_mask[None, None, :, :]
+ locality_mask = locality_mask.to(local_attention_mask.device)
+ return torch.logical_and(local_attention_mask, locality_mask)
+
+
+def _get_local_attention_mask(attention_mask: torch.Tensor, block_len: int, device: torch.device) -> torch.Tensor:
+ """Prepare attention mask to be applied for a local attention."""
+ # [batch_size, num_blocks, block_len]
+ _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, dim=1)
+ # [batch_size, num_block, 3 * block_len]
+ _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_dim=1, sequence_dim=2)
+
+ _blocked_attention_mask = _blocked_attention_mask.unsqueeze(-1)
+ _3blocked_attention_mask = _3blocked_attention_mask.unsqueeze(-2)
+ # [batch_size, num_block, block_len, 3 * block_len]
+ local_attention_mask = torch.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
+ local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
+ # [batch_size, 1, num_block, block_len, 3 * block_len]
+ return local_attention_mask.unsqueeze(1).to(device)
+
+
+def _make_global_fixed_block_ids(
+ attention_mask: torch.Tensor, global_block_size: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Obtain the "fixed block" global id corresponding to each input token.
+
+ This implementation is a simlified version of the original Flaxformr implementation adopted from:
+ https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.
+
+ In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for
+ the whole fixed block, are assigned to the preceding block.
+
+ Padding tokens from the original sequence are represented by -1.
+ """
+ batch_size, seq_len = attention_mask.shape[:2]
+
+ def handle_orphan_tokens(block_ids: torch.Tensor) -> torch.Tensor:
+ block_ends = (torch.arange(seq_len) % global_block_size) == global_block_size - 1
+ block_ends = block_ends.to(block_ids.device)
+ true_block_ends = torch.logical_and(block_ends, block_ids >= 0)
+ full_blocks = true_block_ends.sum(-1).unsqueeze(-1).type(block_ids.dtype) - 1
+ block_ids = torch.where(block_ids < full_blocks, block_ids, full_blocks)
+ return block_ids
+
+ fixed_block_mask = torch.ones_like(attention_mask, device=attention_mask.device) / global_block_size
+ fixed_block_mask = torch.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
+ mask = torch.where(attention_mask != 0.0, 1.0, -1000.0).type(attention_mask.dtype)
+ global_block_ids = torch.floor(mask + fixed_block_mask - 1.0).type(attention_mask.dtype)
+ _global_block_ids_lower_bound = torch.tensor(-1.0, dtype=global_block_ids.dtype, device=global_block_ids.device)
+ global_block_ids = torch.where(
+ global_block_ids > _global_block_ids_lower_bound, global_block_ids, _global_block_ids_lower_bound
+ )
+ # set padding tokens to -1
+ global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
+ # [batch_size, seq_len]
+ global_block_ids = handle_orphan_tokens(global_block_ids)
+ num_globals = seq_len // global_block_size
+ # [batch_size, seq_len // global_block_size]
+ if num_globals > 0:
+ _sequence_block_ids_max = torch.max(global_block_ids, dim=-1).values.repeat(num_globals, 1).transpose(0, 1)
+ else:
+ _sequence_block_ids_max = torch.zeros(
+ batch_size, 0, dtype=global_block_ids.dtype, device=global_block_ids.device
+ )
+ global_segment_ids = torch.cumsum(torch.ones(batch_size, num_globals), dim=-1) - 1
+ global_segment_ids = global_segment_ids.to(attention_mask.device)
+ global_segment_ids = torch.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
+ return global_block_ids.type(torch.int), global_segment_ids.type(torch.int)
+
+
+def _make_side_relative_position_ids(attention_mask: torch.Tensor, global_block_size: int) -> torch.Tensor:
+ """Create the relative position tensor for local -> global attention."""
+ block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
+ global_seq_len = global_segment_ids.shape[-1]
+ global_positions = torch.arange(global_seq_len, device=block_ids.device)
+ side_relative_position = global_positions - block_ids[..., None]
+ return side_relative_position.type(torch.int64)
+
+
+def _create_global_aggregates(
+ hidden_states: torch.Tensor, block_ids: torch.Tensor, global_seq_len: int
+) -> torch.Tensor:
+ """Compute individual block aggregates by summing over individual blocks."""
+ # (batch..., seq_len, global_seq_len))
+ block_ids = block_ids.where(
+ block_ids >= 0, torch.tensor(global_seq_len, dtype=block_ids.dtype, device=block_ids.device)
+ )
+ one_hot_block_ids = nn.functional.one_hot(block_ids.type(torch.int64), global_seq_len + 1)[:, :, :-1]
+ return torch.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids.type(hidden_states.dtype))
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LongT5
+class LongT5LayerNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Construct a layernorm module in the LongT5 style. No bias and no subtraction of mean.
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+
+ # LongT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+ # half-precision inputs is done in fp32
+
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+try:
+ from apex.normalization import FusedRMSNorm
+
+ LongT5LayerNorm = FusedRMSNorm # noqa
+
+ logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of LongT5LayerNorm")
+except ImportError:
+ # using the normal LongT5LayerNorm
+ pass
+except Exception:
+ logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
+ pass
+
+
+# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
+class LongT5DenseActDense(nn.Module):
+ def __init__(self, config: LongT5Config):
+ super().__init__()
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+ self.dropout = nn.Dropout(config.dropout_rate)
+ self.act = ACT2FN[config.dense_act_fn]
+
+ def forward(self, hidden_states):
+ hidden_states = self.wi(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->LongT5
+class LongT5DenseGatedActDense(nn.Module):
+ def __init__(self, config: LongT5Config):
+ super().__init__()
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+ self.dropout = nn.Dropout(config.dropout_rate)
+ self.act = ACT2FN[config.dense_act_fn]
+
+ def forward(self, hidden_states):
+ hidden_gelu = self.act(self.wi_0(hidden_states))
+ hidden_linear = self.wi_1(hidden_states)
+ hidden_states = hidden_gelu * hidden_linear
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->LongT5
+class LongT5LayerFF(nn.Module):
+ def __init__(self, config: LongT5Config):
+ super().__init__()
+ if config.is_gated_act:
+ self.DenseReluDense = LongT5DenseGatedActDense(config)
+ else:
+ self.DenseReluDense = LongT5DenseActDense(config)
+
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(self, hidden_states):
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.DenseReluDense(forwarded_states)
+ hidden_states = hidden_states + self.dropout(forwarded_states)
+ return hidden_states
+
+
+# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5
+class LongT5Attention(nn.Module):
+ def __init__(self, config: LongT5Config, has_relative_attention_bias=False):
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
+ self.relative_attention_max_distance = config.relative_attention_max_distance
+ self.d_model = config.d_model
+ self.key_value_proj_dim = config.d_kv
+ self.n_heads = config.num_heads
+ self.dropout = config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ # Mesh TensorFlow initialization to avoid scaling before softmax
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.pruned_heads = set()
+ self.gradient_checkpointing = False
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+ )
+ # Prune linear layers
+ self.q = prune_linear_layer(self.q, index)
+ self.k = prune_linear_layer(self.k, index)
+ self.v = prune_linear_layer(self.v, index)
+ self.o = prune_linear_layer(self.o, index, dim=1)
+ # Update hyper params
+ self.n_heads = self.n_heads - len(heads)
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ return relative_buckets
+
+ def compute_bias(self, query_length, key_length, device=None):
+ """Compute binned relative position bias"""
+ if device is None:
+ device = self.relative_attention_bias.weight.device
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+ relative_position = memory_position - context_position # shape (query_length, key_length)
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position, # shape (query_length, key_length)
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
+ return values
+
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+ past_key_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ real_seq_length = seq_length
+
+ if past_key_value is not None:
+ assert (
+ len(past_key_value) == 2
+ ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
+
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ def unshape(states):
+ """reshape"""
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
+ """projects hidden states correctly to key/query states"""
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(hidden_states))
+ elif past_key_value is None:
+ # cross-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+
+ if past_key_value is not None:
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, key_length, dim_per_head)
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+ else:
+ # cross-attn
+ hidden_states = past_key_value
+ return hidden_states
+
+ # get query states
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
+
+ # get key/value states
+ key_states = project(
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
+ )
+ value_states = project(
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+ )
+
+ # compute scores
+ scores = torch.matmul(
+ query_states, key_states.transpose(3, 2)
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+
+ if position_bias is None:
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+ if past_key_value is not None:
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+
+ if mask is not None:
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
+
+ scores += position_bias
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+ scores
+ ) # (batch_size, n_heads, seq_length, key_length)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ ) # (batch_size, n_heads, seq_length, key_length)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+class LongT5LocalAttention(nn.Module):
+ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
+ self.relative_attention_max_distance = config.relative_attention_max_distance
+ self.d_model = config.d_model
+ self.key_value_proj_dim = config.d_kv
+ self.n_heads = config.num_heads
+ self.local_radius = config.local_radius
+ self.block_len = self.local_radius + 1
+ self.dropout = config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ # Mesh TensorFlow initialization to avoid scaling before softmax
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.pruned_heads = set()
+ self.gradient_checkpointing = False
+
+ # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+ )
+ # Prune linear layers
+ self.q = prune_linear_layer(self.q, index)
+ self.k = prune_linear_layer(self.k, index)
+ self.v = prune_linear_layer(self.v, index)
+ self.o = prune_linear_layer(self.o, index, dim=1)
+ # Update hyper params
+ self.n_heads = self.n_heads - len(heads)
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @staticmethod
+ # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ return relative_buckets
+
+ def compute_bias(self, block_length: int):
+ """Compute binned relative position bias"""
+ memory_position = torch.arange(
+ 3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
+ )
+ context_position = memory_position[block_length:-block_length]
+
+ # (block_length, 3 * block_length)
+ relative_position = memory_position[None, :] - context_position[:, None]
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position, # (block_length, 3 * block_length)
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ # (block_length, 3 * block_length, num_heads)
+ values = self.relative_attention_bias(relative_position_bucket)
+ # (1, 1, num_heads, block_length, 3 * block_length)
+ values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
+ return values
+
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ output_attentions=False,
+ ):
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+
+ def unshape(states):
+ """reshape"""
+ return states.contiguous().view(batch_size, -1, self.inner_dim)
+
+ # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
+ query_states = shape(self.q(hidden_states))
+ key_states = shape(self.k(hidden_states))
+ value_states = shape(self.v(hidden_states))
+
+ # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
+ query_states = _split_into_blocks(query_states, self.block_len, dim=1)
+ key_states = _split_into_blocks(key_states, self.block_len, dim=1)
+ value_states = _split_into_blocks(value_states, self.block_len, dim=1)
+
+ # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
+ key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
+ value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
+
+ # Compute scores
+ scores = torch.einsum(
+ "...qhd,...khd->...hqk", query_states, key_states
+ ) # (batch_size, num_block, n_heads, block_len, 3 * block_len)
+
+ if position_bias is None:
+ # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(self.block_len)
+
+ if mask is not None:
+ # Replace masked positions with -1e10 (according to the original implementation)
+ mask = torch.where(mask > 0, 0.0, -1e10)
+ # We need to adjust position bias shape to be sum with mask
+ position_bias = position_bias + mask.transpose(1, 2)
+
+ scores += position_bias
+ # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+ attn_weights = attn_weights.type(value_states.dtype)
+ attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
+ attn_output = attn_output[:, :seq_length, :]
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = None
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+class LongT5TransientGlobalAttention(nn.Module):
+ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
+ self.relative_attention_max_distance = config.relative_attention_max_distance
+ self.d_model = config.d_model
+ self.key_value_proj_dim = config.d_kv
+ self.n_heads = config.num_heads
+ self.local_radius = config.local_radius
+ self.block_len = self.local_radius + 1
+ self.global_block_size = config.global_block_size
+ self.dropout = config.dropout_rate
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+ # Mesh TensorFlow initialization to avoid scaling before softmax
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.pruned_heads = set()
+ self.gradient_checkpointing = False
+
+ # Relativen attention bias & Layer norm for global attention
+ if self.has_relative_attention_bias:
+ self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+
+ # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+ )
+ # Prune linear layers
+ self.q = prune_linear_layer(self.q, index)
+ self.k = prune_linear_layer(self.k, index)
+ self.v = prune_linear_layer(self.v, index)
+ self.o = prune_linear_layer(self.o, index, dim=1)
+ # Update hyper params
+ self.n_heads = self.n_heads - len(heads)
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ @staticmethod
+ # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+ return relative_buckets
+
+ def compute_bias(self, block_length: int):
+ """Compute binned relative position bias"""
+ memory_position = torch.arange(
+ 3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
+ )
+ context_position = memory_position[block_length:-block_length]
+
+ # (block_length, 3 * block_length)
+ relative_position = memory_position[None, :] - context_position[:, None]
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position, # (block_length, 3 * block_length)
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ # (block_length, 3 * block_length, num_heads)
+ values = self.relative_attention_bias(relative_position_bucket)
+ # (1, 1, num_heads, block_length, 3 * block_length)
+ values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
+ return values
+
+ def compute_side_bias(self, mask: torch.Tensor, global_segment_ids: torch.Tensor) -> torch.Tensor:
+ # (batch_size, 1, seq_len, global_seq_len)
+ side_attention_mask = torch.eq(mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
+ attention_side_bias = torch.where(side_attention_mask > 0, 0.0, -1e10)
+ # (batch_size, seq_len, global_seq_len)
+ side_relative_position = _make_side_relative_position_ids(mask, self.global_block_size)
+ side_relative_position_bucket = self._relative_position_bucket(
+ side_relative_position,
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ # (batch_size, seq_len, global_seq_len, num_heads)
+ side_bias = self.global_relative_attention_bias(side_relative_position_bucket)
+
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ side_bias = side_bias.permute([0, 3, 1, 2])
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ attention_side_bias = attention_side_bias + side_bias
+ return attention_side_bias
+
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ output_attentions=False,
+ ):
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+
+ def unshape(states):
+ """reshape"""
+ return states.contiguous().view(batch_size, -1, self.inner_dim)
+
+ # Prepare components for transient-global attention
+ # Obtain block_ids and global_segment_ids
+ # global_seq_len := seq_len // self.global_block_size
+ # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
+ block_ids, global_segment_ids = _make_global_fixed_block_ids(
+ mask if mask is not None else torch.ones(hidden_states.shape[:-1]),
+ self.global_block_size,
+ )
+ # Create global inputs
+ _global_seq_len = global_segment_ids.shape[-1]
+ global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
+ global_inputs = self.global_input_layer_norm(global_inputs)
+
+ # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
+ query_states = shape(self.q(hidden_states))
+ key_states = shape(self.k(hidden_states))
+ value_states = shape(self.v(hidden_states))
+ # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head)
+ side_key_states = shape(self.k(global_inputs))
+ side_value_states = shape(self.v(global_inputs))
+
+ # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
+ query_states = _split_into_blocks(query_states, self.block_len, dim=1)
+ key_states = _split_into_blocks(key_states, self.block_len, dim=1)
+ value_states = _split_into_blocks(value_states, self.block_len, dim=1)
+
+ # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
+ key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
+ value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
+
+ # Tile side inputs across local key/value blocks
+ # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
+ reps = [1] * (side_key_states.ndim + 1)
+ reps[1] = key_states.shape[1]
+ side_key_states = side_key_states.unsqueeze(1).repeat(reps)
+ side_value_states = side_value_states.unsqueeze(1).repeat(reps)
+
+ # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
+ # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
+ key_states = torch.cat([key_states, side_key_states], dim=2)
+ value_states = torch.cat([value_states, side_value_states], dim=2)
+
+ # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len + global_seq_len)
+ scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)
+
+ if mask is not None:
+ # We need to adjust position bias shape to be sum with mask
+ local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device)
+ # Replace masked positions with -10_000 (according to the original implementation)
+ local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10)
+ else:
+ local_attention_mask = None
+
+ if position_bias is None:
+ # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, 1, self.n_heads, self.block_len, 3 * self.block_len),
+ device=scores.device,
+ dtype=scores.dtype,
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(self.block_len)
+
+ if local_attention_mask is not None:
+ # (batch_size, 1, n_heads, block_len, 3 * block_len)
+ position_bias = position_bias + local_attention_mask.transpose(1, 2)
+ position_bias = position_bias.type(scores.dtype)
+
+ # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len)
+ if mask is None:
+ mask = torch.ones(batch_size, seq_length)
+ # (batch_size, num_heads, seq_len, global_seq_len)
+ side_position_bias = self.compute_side_bias(mask, global_segment_ids)
+ # (batch_size, num_blocks, num_heads, block_len, global_seq_len)
+ side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2)
+ side_position_bias = side_position_bias.type(scores.dtype).to(scores.device)
+ # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len)
+ position_bias = torch.cat([position_bias, side_position_bias], dim=-1)
+
+ scores += position_bias
+ # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+ attn_weights = attn_weights.type(value_states.dtype)
+ attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
+ attn_output = attn_output[:, :seq_length, :]
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = None
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
+class LongT5LayerSelfAttention(nn.Module):
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+ self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.SelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0])
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class LongT5LayerLocalSelfAttention(nn.Module):
+ """Local self attention used in encoder"""
+
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+ self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias)
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ output_attentions=False,
+ **kwargs: Any, # to accept past_key_value and use_cache kwargs
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.LocalSelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0])
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class LongT5LayerTransientGlobalSelfAttention(nn.Module):
+ """Transient-Global self attention used in encoder"""
+
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+ self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention(
+ config, has_relative_attention_bias=has_relative_attention_bias
+ )
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ output_attentions=False,
+ **kwargs: Any, # to accept past_key_value and use_cache kwargs
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.TransientGlobalSelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + self.dropout(attention_output[0])
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5
+class LongT5LayerCrossAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False)
+ self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ key_value_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ query_length=None,
+ output_attentions=False,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.EncDecAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+ )
+ layer_output = hidden_states + self.dropout(attention_output[0])
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+
+class LongT5Block(nn.Module):
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+ self.is_decoder = config.is_decoder
+ if config.is_decoder:
+ attention_layer = LongT5LayerSelfAttention
+ elif config.encoder_attention_type == "local":
+ attention_layer = LongT5LayerLocalSelfAttention
+ elif config.encoder_attention_type == "transient-global":
+ attention_layer = LongT5LayerTransientGlobalSelfAttention
+ else:
+ raise ValueError(
+ "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
+ f"but got {config.encoder_attention_type}."
+ )
+ self.layer = nn.ModuleList()
+ self.layer.append(attention_layer(config, has_relative_attention_bias=has_relative_attention_bias))
+ if self.is_decoder:
+ self.layer.append(LongT5LayerCrossAttention(config))
+
+ self.layer.append(LongT5LayerFF(config))
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ encoder_decoder_position_bias=None,
+ layer_head_mask=None,
+ cross_attn_layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ output_attentions=False,
+ return_dict=True,
+ ):
+
+ if past_key_value is not None:
+ if not self.is_decoder:
+ logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
+
+ if len(past_key_value) != expected_num_past_key_values:
+ raise ValueError(
+ f"There should be {expected_num_past_key_values} past states. "
+ f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
+ f"Got {len(past_key_value)} past key / value states"
+ )
+
+ self_attn_past_key_value = past_key_value[:2]
+ cross_attn_past_key_value = past_key_value[2:]
+ else:
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
+
+ self_attention_outputs = self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=self_attn_past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
+
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
+ if do_cross_attention:
+ # the actual query length is unknown for cross attention
+ # if using past key value states. Need to inject it here
+ if present_key_value_state is not None:
+ query_length = present_key_value_state[0].shape[2]
+ else:
+ query_length = None
+
+ cross_attention_outputs = self.layer[1](
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_bias=encoder_decoder_position_bias,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ query_length=query_length,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = cross_attention_outputs[0]
+
+ # Combine self attn and cross attn key value states
+ if present_key_value_state is not None:
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
+
+ # Keep cross-attention outputs and relative position weights
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
+
+ # Apply Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states)
+
+ outputs = (hidden_states,)
+
+ if use_cache:
+ outputs = outputs + (present_key_value_state,) + attention_outputs
+ else:
+ outputs = outputs + attention_outputs
+
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+
+
+class LongT5PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = LongT5Config
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+
+ @property
+ # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
+ def dummy_inputs(self):
+ input_ids = torch.tensor(DUMMY_INPUTS)
+ input_mask = torch.tensor(DUMMY_MASK)
+ dummy_inputs = {
+ "decoder_input_ids": input_ids,
+ "input_ids": input_ids,
+ "decoder_attention_mask": input_mask,
+ }
+ return dummy_inputs
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ factor = self.config.initializer_factor # Used for testing weights initialization
+ if isinstance(module, LongT5LayerNorm):
+ module.weight.data.fill_(factor * 1.0)
+ elif isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)):
+ # Mesh TensorFlow embeddings initialization
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+ elif isinstance(module, LongT5DenseActDense):
+ # Mesh TensorFlow FF initialization
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
+ module.wi.bias.data.zero_()
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
+ module.wo.bias.data.zero_()
+ elif isinstance(module, LongT5DenseGatedActDense):
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
+ module.wi_0.bias.data.zero_()
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
+ module.wi_1.bias.data.zero_()
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
+ module.wo.bias.data.zero_()
+ elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)):
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
+ d_model = self.config.d_model
+ key_value_proj_dim = self.config.d_kv
+ n_heads = self.config.num_heads
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
+ if module.has_relative_attention_bias:
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
+ if isinstance(module, LongT5TransientGlobalAttention):
+ module.global_relative_attention_bias.weight.data.normal_(
+ mean=0.0, std=factor * ((d_model) ** -0.5)
+ )
+
+ # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (LongT5Attention, LongT5Stack)):
+ module.gradient_checkpointing = value
+
+ # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
+ def _shift_right(self, input_ids):
+ decoder_start_token_id = self.config.decoder_start_token_id
+ pad_token_id = self.config.pad_token_id
+
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the"
+ " pad_token_id. See LongT5 docs for more information"
+ )
+
+ # shift inputs to the right
+ if is_torch_fx_proxy(input_ids):
+ # Item assignment is not supported natively for proxies.
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
+ else:
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+ shifted_input_ids[..., 0] = decoder_start_token_id
+
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
+
+ return shifted_input_ids
+
+
+class LongT5Stack(LongT5PreTrainedModel):
+ def __init__(self, config, embed_tokens=None):
+ super().__init__(config)
+
+ self.embed_tokens = embed_tokens
+ self.is_decoder = config.is_decoder
+
+ self.local_radius = config.local_radius
+ self.block_len = self.local_radius + 1
+
+ self.block = nn.ModuleList(
+ [LongT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+ )
+ self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ self.gradient_checkpointing = False
+
+ # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+
+ if inputs_embeds is None:
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+ if use_cache is True:
+ assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
+ encoder_seq_length = encoder_hidden_states.shape[1]
+ encoder_attention_mask = torch.ones(
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
+ )
+
+ # initialize past_key_values with `None` if past does not exist
+ if past_key_values is None:
+ past_key_values = [None] * len(self.block)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
+ if self.is_decoder:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, inputs_embeds.device
+ )
+ elif self.config.encoder_attention_type == "local":
+ extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
+ else: # we need to use both local attention mask and standard extended mask for transient-global attention
+ extended_attention_mask = attention_mask
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+ present_key_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return tuple(module(*inputs, use_cache, output_attentions))
+
+ return custom_forward
+
+ layer_outputs = checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ extended_attention_mask,
+ position_bias,
+ encoder_hidden_states,
+ encoder_extended_attention_mask,
+ encoder_decoder_position_bias,
+ layer_head_mask,
+ cross_attn_layer_head_mask,
+ None, # past_key_value is always None with gradient checkpointing
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if use_cache is False:
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[2]
+ if self.is_decoder and encoder_hidden_states is not None:
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+ # append next layer key value states
+ if use_cache:
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+ if self.is_decoder:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ present_key_value_states,
+ all_hidden_states,
+ all_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_value_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+LONGT5_START_DOCSTRING = r"""
+
+ The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long
+ Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo
+ Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising
+ generative setting. LongT5 model is an extension of T5 model, and it enables using one of the two different
+ efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LongT5Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+LONGT5_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for detail.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
+ Training](./longt5#training).
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
+ Training](./longt5#training).
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+ `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+ input (see `past_key_values`). This is useful if you want more control over how to convert
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+ of `inputs_embeds`.
+
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+LONGT5_ENCODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
+ you should be able to pad the inputs on both the right and the left.
+
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for detail.
+
+ To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
+ Training](./longt5#training).
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+__HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+num_heads)`.
+"""
+
+
+@add_start_docstrings(
+ "The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.",
+ LONGT5_START_DOCSTRING,
+)
+class LongT5Model(LongT5PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
+ ]
+ _keys_to_ignore_on_load_unexpected = [
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+ ]
+
+ def __init__(self, config: LongT5Config):
+ super().__init__(config)
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.is_decoder = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = LongT5Stack(encoder_config, self.shared)
+
+ decoder_config = copy.deepcopy(config)
+ decoder_config.is_decoder = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = config.num_decoder_layers
+ self.decoder = LongT5Stack(decoder_config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.set_input_embeddings(new_embeddings)
+ self.decoder.set_input_embeddings(new_embeddings)
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import T5Tokenizer, LongT5Model
+
+ >>> tokenizer = T5Tokenizer.from_pretrained("google/long-t5-local-base")
+ >>> model = LongT5Model.from_pretrained("google/long-t5-local-base")
+
+ >>> # Let's try a very long encoder input.
+ >>> input_ids = tokenizer(
+ ... 100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+ ... ).input_ids # Batch size 1
+
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
+
+ >>> # forward pass
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+ if head_mask is not None and decoder_head_mask is None:
+ if self.config.num_layers == self.config.num_decoder_layers:
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+ decoder_head_mask = head_mask
+
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=past_key_values,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
+class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
+ r"lm_head.weight",
+ ]
+ _keys_to_ignore_on_load_unexpected = [
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+ ]
+
+ def __init__(self, config: LongT5Config):
+ super().__init__(config)
+ self.model_dim = config.d_model
+
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.is_decoder = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = LongT5Stack(encoder_config, self.shared)
+
+ decoder_config = copy.deepcopy(config)
+ decoder_config.is_decoder = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = config.num_decoder_layers
+ self.decoder = LongT5Stack(decoder_config, self.shared)
+
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.set_input_embeddings(new_embeddings)
+ self.decoder.set_input_embeddings(new_embeddings)
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
+ labels in `[0, ..., config.vocab_size]`
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
+ >>> model = LongT5ForConditionalGeneration.from_pretrained(
+ ... "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
+ ... )
+
+ >>> # Let's try a very long input.
+ >>> input_ids = tokenizer(
+ ... "summarize: " + 100 * "studies have shown that owning a dog is good for you ", return_tensors="pt"
+ ... ).input_ids # Batch size 1
+
+ >>> outputs = model.generate(input_ids)
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ abstractthe aim of this article is to summarize the studies have shown that owning a dog
+ ```"""
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+ if head_mask is not None and decoder_head_mask is None:
+ if self.config.num_layers == self.config.num_decoder_layers:
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+ decoder_head_mask = head_mask
+
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ # Convert encoder inputs in embeddings if needed
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+ # get decoder inputs from shifting lm labels to the right
+ decoder_input_ids = self._shift_right(labels)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=past_key_values,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.model_dim**-0.5)
+
+ lm_logits = self.lm_head(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+
+ if not return_dict:
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs
+ ):
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "decoder_input_ids": input_ids,
+ "past_key_values": past,
+ "encoder_outputs": encoder_outputs,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return self._shift_right(labels)
+
+ def _reorder_cache(self, past, beam_idx):
+ # if decoder past is not included in output
+ # speedy decoding is disabled and no need to reorder
+ if past is None:
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
+ return past
+
+ reordered_decoder_past = ()
+ for layer_past_states in past:
+ # get the correct batch idx from layer past batch dim
+ # batch dim of `past` is at 2nd position
+ reordered_layer_past_states = ()
+ for layer_past_state in layer_past_states:
+ # need to set correct `past` for each of the four key / value states
+ reordered_layer_past_states = reordered_layer_past_states + (
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
+ )
+
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
+ assert len(reordered_layer_past_states) == len(layer_past_states)
+
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+ return reordered_decoder_past
+
+
+@add_start_docstrings(
+ "The bare LONGT5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
+ LONGT5_START_DOCSTRING,
+)
+class LongT5EncoderModel(LongT5PreTrainedModel):
+ authorized_missing_keys = [
+ r"encoder.embed_tokens.weight",
+ ]
+
+ def __init__(self, config: LongT5Config):
+ super().__init__(config)
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = LongT5Stack(encoder_config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.set_input_embeddings(new_embeddings)
+
+ def get_encoder(self):
+ return self.encoder
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(LONGT5_ENCODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
+ >>> model = LongT5EncoderModel.from_pretrained("google/long-t5-local-base")
+ >>> input_ids = tokenizer(
+ ... 100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt"
+ ... ).input_ids # Batch size 1
+ >>> outputs = model(input_ids=input_ids)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ return encoder_outputs
diff --git a/src/transformers/models/luke/__init__.py b/src/transformers/models/luke/__init__.py
index d18d016b5026c7..36ca833aaab6d3 100644
--- a/src/transformers/models/luke/__init__.py
+++ b/src/transformers/models/luke/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_luke": ["LukeTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_luke"] = [
"LUKE_PRETRAINED_MODEL_ARCHIVE_LIST",
"LukeForEntityClassification",
@@ -42,7 +47,12 @@
from .configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig
from .tokenization_luke import LukeTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_luke import (
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST,
LukeForEntityClassification,
diff --git a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
index 520ae61b43ece9..d2b2323b289c70 100644
--- a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
@@ -77,13 +77,17 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
raise ValueError(f"Missing keys {', '.join(missing_keys)}. Expected only missing embeddings.position_ids")
if not (all(key.startswith("entity_predictions") or key.startswith("lm_head") for key in unexpected_keys)):
raise ValueError(
- f"Unexpected keys {', '.join([key for key in unexpected_keys if not (key.startswith('entity_predictions') or key.startswith('lm_head'))])}"
+ "Unexpected keys"
+ f" {', '.join([key for key in unexpected_keys if not (key.startswith('entity_predictions') or key.startswith('lm_head'))])}"
)
# Check outputs
tokenizer = LukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")
- text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ text = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the"
+ " new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
span = (39, 42)
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
@@ -116,7 +120,8 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
if not (outputs.entity_last_hidden_state.shape != expected_shape):
raise ValueError(
- f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}"
+ f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is"
+ f" {expected_shape}"
)
if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError
@@ -129,7 +134,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
def load_entity_vocab(entity_vocab_path):
entity_vocab = {}
with open(entity_vocab_path, "r", encoding="utf-8") as f:
- for (index, line) in enumerate(f):
+ for index, line in enumerate(f):
title, _ = line.rstrip().split("\t")
entity_vocab[title] = index
diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py
index cd5a53ddae305b..4c2491aee73096 100644
--- a/src/transformers/models/luke/modeling_luke.py
+++ b/src/transformers/models/luke/modeling_luke.py
@@ -874,7 +874,8 @@ def _set_gradient_checkpointing(self, module, value=False):
@add_start_docstrings(
- "The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any specific head on top.",
+ "The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any"
+ " specific head on top.",
LUKE_START_DOCSTRING,
)
class LukeModel(LukePreTrainedModel):
@@ -953,11 +954,11 @@ def forward(
>>> entities = [
... "BeyoncƩ",
... "Los Angeles",
- >>> ] # Wikipedia entity titles corresponding to the entity mentions "BeyoncƩ" and "Los Angeles"
+ ... ] # Wikipedia entity titles corresponding to the entity mentions "BeyoncƩ" and "Los Angeles"
>>> entity_spans = [
... (0, 7),
... (17, 28),
- >>> ] # character-based entity spans corresponding to "BeyoncƩ" and "Los Angeles"
+ ... ] # character-based entity spans corresponding to "BeyoncƩ" and "Los Angeles"
>>> encoding = tokenizer(
... text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt"
@@ -1228,13 +1229,15 @@ def forward(
loss = mlm_loss
mep_loss = None
- entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
- if entity_labels is not None:
- mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
- if loss is None:
- loss = mep_loss
- else:
- loss = loss + mep_loss
+ entity_logits = None
+ if outputs.entity_last_hidden_state is not None:
+ entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
+ if entity_labels is not None:
+ mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
+ if loss is None:
+ loss = mep_loss
+ else:
+ loss = loss + mep_loss
if not return_dict:
output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
@@ -1435,7 +1438,7 @@ def forward(
>>> entity_spans = [
... (0, 7),
... (17, 28),
- >>> ] # character-based entity spans corresponding to "BeyoncƩ" and "Los Angeles"
+ ... ] # character-based entity spans corresponding to "BeyoncƩ" and "Los Angeles"
>>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py
index e35db36aedee61..3cbc9218c0f9a2 100644
--- a/src/transformers/models/luke/tokenization_luke.py
+++ b/src/transformers/models/luke/tokenization_luke.py
@@ -17,6 +17,7 @@
import itertools
import json
import os
+from collections.abc import Mapping
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
@@ -252,7 +253,8 @@ def __init__(
self.max_entity_length = 2
else:
raise ValueError(
- f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification', 'entity_span_classification'] only."
+ f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification',"
+ " 'entity_span_classification'] only."
)
self.max_mention_length = max_mention_length
@@ -597,7 +599,7 @@ def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spa
raise ValueError("entity_spans should be given as a list")
elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple):
raise ValueError(
- "entity_spans should be given as a list of tuples " "containing the start and end character indices"
+ "entity_spans should be given as a list of tuples containing the start and end character indices"
)
if entities is not None:
@@ -1006,7 +1008,8 @@ def prepare_for_model(
if num_invalid_entities != 0:
logger.warning(
- f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the truncation of input tokens"
+ f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the"
+ " truncation of input tokens"
)
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length:
@@ -1031,7 +1034,7 @@ def prepare_for_model(
entity_position_ids = []
entity_start_positions = []
entity_end_positions = []
- for (token_spans, offset) in (
+ for token_spans, offset in (
(valid_entity_token_spans, entity_token_offset),
(valid_pair_entity_token_spans, pair_entity_token_offset),
):
@@ -1140,7 +1143,7 @@ def pad(
"""
# If we have a list of dicts, let's convert it in a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
- if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
# The model's main input name, usually `input_ids`, has be passed for padding
@@ -1180,7 +1183,7 @@ def pad(
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
- f"Should be one of a python, numpy, pytorch or tensorflow object."
+ "Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in encoded_inputs.items():
@@ -1383,6 +1386,6 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.entity_vocab, ensure_ascii=False))
+ f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return vocab_file, merge_file, entity_vocab_file
diff --git a/src/transformers/models/lxmert/__init__.py b/src/transformers/models/lxmert/__init__.py
index 38d9d5e67e9f23..0b8b58bc998637 100644
--- a/src/transformers/models/lxmert/__init__.py
+++ b/src/transformers/models/lxmert/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +32,20 @@
"tokenization_lxmert": ["LxmertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_lxmert_fast"] = ["LxmertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_lxmert"] = [
"LxmertEncoder",
"LxmertForPreTraining",
@@ -40,7 +56,12 @@
"LxmertXLayer",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_lxmert"] = [
"TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLxmertForPreTraining",
@@ -55,10 +76,20 @@
from .configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from .tokenization_lxmert import LxmertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_lxmert_fast import LxmertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_lxmert import (
LxmertEncoder,
LxmertForPreTraining,
@@ -69,7 +100,12 @@
LxmertXLayer,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_lxmert import (
TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLxmertForPreTraining,
diff --git a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py
index 7debd71af3b39c..f8eb86f1d1e48a 100755
--- a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py
@@ -51,8 +51,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained model. \n"
- "This specifies the model architecture.",
+ help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py
index c9b2541251e855..34b90441fede21 100644
--- a/src/transformers/models/lxmert/modeling_lxmert.py
+++ b/src/transformers/models/lxmert/modeling_lxmert.py
@@ -336,7 +336,7 @@ def transpose_for_scores(self, x):
self.num_attention_heads,
self.attention_head_size,
)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
@@ -365,7 +365,7 @@ def forward(self, hidden_states, context, attention_mask=None, output_attentions
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
@@ -1193,7 +1193,8 @@ def forward(
if "masked_lm_labels" in kwargs:
warnings.warn(
- "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels`"
+ " instead.",
FutureWarning,
)
labels = kwargs.pop("masked_lm_labels")
@@ -1252,7 +1253,7 @@ def forward(
visual_prediction_scores = visual_prediction_scores_dict[key]
visual_loss = visual_loss_fct(
visual_prediction_scores.view(-1, output_dim),
- label.view(*label_shape),
+ label.view(label_shape),
)
if visual_loss.dim() > 1: # Regression Losses
visual_loss = visual_loss.mean(1)
diff --git a/src/transformers/models/lxmert/tokenization_lxmert_fast.py b/src/transformers/models/lxmert/tokenization_lxmert_fast.py
index 9e88bc1581cb0e..8cfa20a9a26f7f 100644
--- a/src/transformers/models/lxmert/tokenization_lxmert_fast.py
+++ b/src/transformers/models/lxmert/tokenization_lxmert_fast.py
@@ -24,7 +24,9 @@
"unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/vocab.txt",
},
"tokenizer_file": {
- "unc-nlp/lxmert-base-uncased": "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/tokenizer.json",
+ "unc-nlp/lxmert-base-uncased": (
+ "https://huggingface.co/unc-nlp/lxmert-base-uncased/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/m2m_100/__init__.py b/src/transformers/models/m2m_100/__init__.py
index 81d664d0f79ba5..23b7e2a46cbe5e 100644
--- a/src/transformers/models/m2m_100/__init__.py
+++ b/src/transformers/models/m2m_100/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_m2m_100"] = [
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration",
@@ -39,7 +44,12 @@
from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config, M2M100OnnxConfig
from .tokenization_m2m_100 import M2M100Tokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_m2m_100 import (
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
M2M100ForConditionalGeneration,
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 36539736bf84e9..90de52d4c351ed 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -101,7 +101,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
@@ -288,7 +288,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -304,7 +305,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -325,7 +327,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -565,7 +568,7 @@ def _set_gradient_checkpointing(self, module, value=False):
"""
M2M_100_GENERATION_EXAMPLE = r"""
- Translation example::
+ Translation example:
```python
>>> from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
@@ -793,7 +796,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
@@ -994,7 +998,7 @@ def forward(
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1025,7 +1029,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
@@ -1046,7 +1051,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting"
+ " `use_cache=False`..."
)
use_cache = False
@@ -1235,9 +1241,9 @@ def forward(
class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
]
@@ -1299,22 +1305,7 @@ def forward(
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
-
- Example:
-
- ```python
- >>> from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
-
- >>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
- >>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
-
- >>> text_to_translate = "Life is like a box of chocolates"
- >>> model_inputs = tokenizer(text_to_translate, return_tensors="pt")
-
- >>> # translate to French
- >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
- >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))
- ```"""
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
diff --git a/src/transformers/models/marian/__init__.py b/src/transformers/models/marian/__init__.py
index 5971d2d5743bc6..eaaaf290821bc5 100644
--- a/src/transformers/models/marian/__init__.py
+++ b/src/transformers/models/marian/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -31,10 +32,20 @@
"configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig", "MarianOnnxConfig"],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_marian"] = ["MarianTokenizer"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_marian"] = [
"MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST",
"MarianForCausalLM",
@@ -43,19 +54,39 @@
"MarianPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_marian"] = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"]
if TYPE_CHECKING:
from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig, MarianOnnxConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_marian import MarianTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_marian import (
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST,
MarianForCausalLM,
@@ -64,10 +95,20 @@
MarianPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
else:
diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py
index 835b317f9d9240..f662d388448bb4 100644
--- a/src/transformers/models/marian/configuration_marian.py
+++ b/src/transformers/models/marian/configuration_marian.py
@@ -327,8 +327,9 @@ def _generate_dummy_inputs_for_causal_lm(
self._config.hidden_size // num_encoder_attention_heads,
)
+ mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat(
- [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
diff --git a/src/transformers/models/marian/convert_marian_to_pytorch.py b/src/transformers/models/marian/convert_marian_to_pytorch.py
index bd8490cb2d6204..1fb5a34f064fd3 100644
--- a/src/transformers/models/marian/convert_marian_to_pytorch.py
+++ b/src/transformers/models/marian/convert_marian_to_pytorch.py
@@ -140,17 +140,21 @@ def find_model_file(dest_dir): # this one better
"opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv",
"opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi",
- "opus-mt-en-ROMANCE": "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
- "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
- "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la",
+ "opus-mt-en-ROMANCE": (
+ "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
+ "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
+ "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
+ ),
"opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv",
"opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
"opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms",
"opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
- "opus-mt-ROMANCE-en": "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
- "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
- "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en",
+ "opus-mt-ROMANCE-en": (
+ "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
+ "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
+ "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en"
+ ),
"opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en",
"opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py
index 8fea39e19aebab..da2e4a1fe5b51f 100644
--- a/src/transformers/models/marian/modeling_flax_marian.py
+++ b/src/transformers/models/marian/modeling_flax_marian.py
@@ -551,7 +551,7 @@ def setup(self) -> None:
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
- self.config.encoder_ffn_dim,
+ self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index 65a471d6417cac..0dc30ed0b476c7 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -103,7 +103,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
@@ -233,7 +233,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -249,7 +250,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -270,7 +272,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -853,7 +856,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -993,9 +996,10 @@ def forward(
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
- assert attn_mask.size()[0] == (
- len(self.layers)
- ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
@@ -1268,9 +1272,9 @@ class MarianMTModel(MarianPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
r"embed_positions",
]
diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py
index 04a24ac9f9f1ff..d5f41abe13378b 100644
--- a/src/transformers/models/marian/modeling_tf_marian.py
+++ b/src/transformers/models/marian/modeling_tf_marian.py
@@ -267,7 +267,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -277,7 +280,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -293,7 +299,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -310,7 +319,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -784,7 +796,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
# encoder layers
@@ -983,7 +998,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
- message=f"The {attn_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+ message=(
+ f"The {attn_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py
index 3579d5dffa1807..62f145e7b79820 100644
--- a/src/transformers/models/marian/tokenization_marian.py
+++ b/src/transformers/models/marian/tokenization_marian.py
@@ -47,7 +47,9 @@
"Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json"
},
"tokenizer_config_file": {
- "Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/tokenizer_config.json"
+ "Helsinki-NLP/opus-mt-en-de": (
+ "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/tokenizer_config.json"
+ )
},
}
diff --git a/src/transformers/models/maskformer/__init__.py b/src/transformers/models/maskformer/__init__.py
index 2f15ed34f0c2a1..4234f76dc565cc 100644
--- a/src/transformers/models/maskformer/__init__.py
+++ b/src/transformers/models/maskformer/__init__.py
@@ -17,18 +17,26 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
-}
+_import_structure = {"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_maskformer"] = ["MaskFormerFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_maskformer"] = [
"MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"MaskFormerForInstanceSegmentation",
@@ -39,9 +47,19 @@
if TYPE_CHECKING:
from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_maskformer import MaskFormerFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_maskformer import (
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
MaskFormerForInstanceSegmentation,
diff --git a/src/transformers/models/maskformer/configuration_maskformer.py b/src/transformers/models/maskformer/configuration_maskformer.py
index 50ad6880adb288..ab68de3f0453cf 100644
--- a/src/transformers/models/maskformer/configuration_maskformer.py
+++ b/src/transformers/models/maskformer/configuration_maskformer.py
@@ -24,7 +24,9 @@
MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/maskformer-swin-base-ade": "https://huggingface.co/facebook/maskformer-swin-base-ade/blob/main/config.json"
+ "facebook/maskformer-swin-base-ade": (
+ "https://huggingface.co/facebook/maskformer-swin-base-ade/blob/main/config.json"
+ )
# See all MaskFormer models at https://huggingface.co/models?filter=maskformer
}
@@ -130,7 +132,8 @@ def __init__(
backbone_model_type = backbone_config.pop("model_type")
if backbone_model_type not in self.backbones_supported:
raise ValueError(
- f"Backbone {backbone_model_type} not supported, please use one of {','.join(self.backbones_supported)}"
+ f"Backbone {backbone_model_type} not supported, please use one of"
+ f" {','.join(self.backbones_supported)}"
)
backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
@@ -141,7 +144,8 @@ def __init__(
decoder_type = decoder_config.pop("model_type")
if decoder_type not in self.decoders_supported:
raise ValueError(
- f"Transformer Decoder {decoder_type} not supported, please use one of {','.join(self.decoders_supported)}"
+ f"Transformer Decoder {decoder_type} not supported, please use one of"
+ f" {','.join(self.decoders_supported)}"
)
decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
diff --git a/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
index 045d2bc0f515d0..c08591e044db9e 100644
--- a/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
@@ -188,7 +188,7 @@ def __init__(self, original_model: nn.Module, config: MaskFormerConfig):
self.config = config
def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict):
- for (src_key, dst_key) in renamed_keys:
+ for src_key, dst_key in renamed_keys:
dst_state_dict[dst_key] = src_state_dict.pop(src_key)
def replace_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: MaskFormerConfig):
@@ -643,12 +643,18 @@ def get_name(checkpoint_file: Path):
parser.add_argument(
"--checkpoints_dir",
type=Path,
- help="A directory containing the model's checkpoints. The directory has to have the following structure: //.pkl",
+ help=(
+ "A directory containing the model's checkpoints. The directory has to have the following structure:"
+ " //.pkl"
+ ),
)
parser.add_argument(
"--configs_dir",
type=Path,
- help="A directory containing the model's configs, see detectron2 doc. The directory has to have the following structure: //.yaml",
+ help=(
+ "A directory containing the model's configs, see detectron2 doc. The directory has to have the following"
+ " structure: //.yaml"
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
@@ -660,7 +666,10 @@ def get_name(checkpoint_file: Path):
"--maskformer_dir",
required=True,
type=Path,
- help="A path to MaskFormer's original implementation directory. You can download from here: https://github.com/facebookresearch/MaskFormer",
+ help=(
+ "A path to MaskFormer's original implementation directory. You can download from here:"
+ " https://github.com/facebookresearch/MaskFormer"
+ ),
)
args = parser.parse_args()
diff --git a/src/transformers/models/maskformer/feature_extraction_maskformer.py b/src/transformers/models/maskformer/feature_extraction_maskformer.py
index 5e466f2ddb07e3..3a5fd49d80fa77 100644
--- a/src/transformers/models/maskformer/feature_extraction_maskformer.py
+++ b/src/transformers/models/maskformer/feature_extraction_maskformer.py
@@ -253,8 +253,9 @@ def __call__(
if not valid_segmentation_maps:
raise ValueError(
- "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
+ " example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
+ " examples)."
)
is_batched = bool(
@@ -591,7 +592,7 @@ def post_process_panoptic_segmentation(
# mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
# now, we need to iterate over the batch size to correctly process the segmentation we got from the queries using our thresholds. Even if the original predicted masks have the same shape across the batch, they won't after thresholding so batch-wise operations are impossible
results: List[Dict[str, Tensor]] = []
- for (mask_probs, pred_scores, pred_labels) in zip(mask_probs, pred_scores, pred_labels):
+ for mask_probs, pred_scores, pred_labels in zip(mask_probs, pred_scores, pred_labels):
mask_probs, pred_scores, pred_labels = self.remove_low_and_no_objects(
mask_probs, pred_scores, pred_labels, object_mask_threshold, num_labels
)
diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py
index 339de6eeeb10b3..64c8d0029cfbcb 100644
--- a/src/transformers/models/maskformer/modeling_maskformer.py
+++ b/src/transformers/models/maskformer/modeling_maskformer.py
@@ -496,7 +496,7 @@ def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
- batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+ batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
return windows
@@ -664,7 +664,7 @@ def __init__(self, config, dim, num_heads):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
- f"The hidden size ({dim}) is not a multiple of the number of attention " f"heads ({num_heads})"
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
@@ -697,16 +697,16 @@ def __init__(self, config, dim, num_heads):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
@@ -750,7 +750,7 @@ def forward(
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -764,7 +764,7 @@ def __init__(self, config, dim):
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- def forward(self, hidden_states, input_tensor):
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
@@ -797,7 +797,13 @@ def prune_heads(self, heads):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
- def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@@ -814,7 +820,7 @@ def __init__(self, config, dim):
else:
self.intermediate_act_fn = config.hidden_act
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
@@ -827,7 +833,7 @@ def __init__(self, config, dim):
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
@@ -1188,7 +1194,8 @@ def __init__(
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
@@ -1252,7 +1259,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -1281,7 +1289,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -1949,7 +1958,7 @@ def outputs_shapes(self) -> List[int]:
return [layer.dim for layer in self.model.encoder.layers]
-class MaskFormerFPNConvLayer(nn.Sequential):
+class MaskFormerFPNConvLayer(nn.Module):
def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
"""
A basic module that executes conv - norm - in sequence used in MaskFormer.
@@ -1960,11 +1969,26 @@ def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, pa
out_features (`int`):
The number of outputs features (channels).
"""
- super().__init__(
+ super().__init__()
+ self.layers = [
nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False),
nn.GroupNorm(32, out_features),
nn.ReLU(inplace=True),
- )
+ ]
+ for i, layer in enumerate(self.layers):
+ # Provide backwards compatibility from when the class inherited from nn.Sequential
+ # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+ # In nn.Module subclasses they derived from the instance attribute they are assigned to e.g.
+ # self.my_layer_name = Layer()
+ # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register
+ # explicitly
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
class MaskFormerFPNLayer(nn.Module):
@@ -2092,7 +2116,22 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
return pos
-class MaskformerMLPPredictionHead(nn.Sequential):
+class PredictionBlock(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
+ super().__init__()
+ self.layers = [nn.Linear(in_dim, out_dim), activation]
+ # Maintain submodule indexing as if part of a Sequential block
+ for i, layer in enumerate(self.layers):
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
+
+class MaskformerMLPPredictionHead(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
"""
A classic Multi Layer Perceptron (MLP).
@@ -2107,18 +2146,28 @@ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers:
num_layers (int, *optional*, defaults to 3):
The number of layers.
"""
+ super().__init__()
in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
- layers = []
+ self.layers = []
for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
-
- layer = nn.Sequential(
- nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True) if i < num_layers - 1 else nn.Identity()
- )
- layers.append(layer)
-
- super().__init__(*layers)
+ activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
+ layer = PredictionBlock(in_dim, out_dim, activation=activation)
+ self.layers.append(layer)
+ # Provide backwards compatibility from when the class inherited from nn.Sequential
+ # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+ # In nn.Module subclasses they derived from the instance attribute they are assigned to e.g.
+ # self.my_layer_name = Layer()
+ # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register
+ # explicitly
+ self.add_module(str(i), layer)
+
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
class MaskFormerPixelLevelModule(nn.Module):
@@ -2244,20 +2293,21 @@ def _init_weights(self, module: nn.Module):
nn.init.constant_(module.input_projection.bias, 0)
# FPN
elif isinstance(module, MaskFormerFPNModel):
- nn.init.xavier_uniform_(module.stem[0].weight, gain=xavier_std)
+ nn.init.xavier_uniform_(module.stem.get_submodule("0").weight, gain=xavier_std)
elif isinstance(module, MaskFormerFPNLayer):
nn.init.xavier_uniform_(module.proj[0].weight, gain=xavier_std)
elif isinstance(module, MaskFormerFPNConvLayer):
- nn.init.xavier_uniform_(module[0].weight, gain=xavier_std)
+ nn.init.xavier_uniform_(module.get_submodule("0").weight, gain=xavier_std)
# The MLP head
elif isinstance(module, MaskformerMLPPredictionHead):
# I was not able to find the correct initializer in the original implementation
# we'll use xavier
- for layer in module:
- nn.init.xavier_uniform_(layer[0].weight, gain=xavier_std)
- nn.init.constant_(layer[0].bias, 0)
+ for submodule in module.modules():
+ if isinstance(submodule, nn.Linear):
+ nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
+ nn.init.constant_(submodule.bias, 0)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py
index 294eb15f0366d3..ef967c2482a19c 100644
--- a/src/transformers/models/mbart/__init__.py
+++ b/src/transformers/models/mbart/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"],
-}
+_import_structure = {"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mbart"] = ["MBartTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mbart"] = [
"MBART_PRETRAINED_MODEL_ARCHIVE_LIST",
"MBartForCausalLM",
@@ -48,14 +62,24 @@
"MBartPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_mbart"] = [
"FlaxMBartForConditionalGeneration",
"FlaxMBartForQuestionAnswering",
@@ -68,13 +92,28 @@
if TYPE_CHECKING:
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mbart import MBartTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mbart_fast import MBartTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mbart import (
MBART_PRETRAINED_MODEL_ARCHIVE_LIST,
MBartForCausalLM,
@@ -85,10 +124,20 @@
MBartPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_mbart import (
FlaxMBartForConditionalGeneration,
FlaxMBartForQuestionAnswering,
diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py
index e4da61442d16d4..af67cf858db177 100644
--- a/src/transformers/models/mbart/configuration_mbart.py
+++ b/src/transformers/models/mbart/configuration_mbart.py
@@ -322,8 +322,9 @@ def _generate_dummy_inputs_for_causal_lm(
self._config.hidden_size // num_encoder_attention_heads,
)
+ mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat(
- [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py
index 141d2b10415eff..7cb52033b78a6f 100644
--- a/src/transformers/models/mbart/modeling_flax_mbart.py
+++ b/src/transformers/models/mbart/modeling_flax_mbart.py
@@ -550,7 +550,7 @@ def setup(self) -> None:
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
- self.config.encoder_ffn_dim,
+ self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 78d094922ba1af..d342d5fcbf3b4f 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -119,7 +119,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
@@ -236,7 +236,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -252,7 +253,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -273,7 +275,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -808,7 +811,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -905,7 +909,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1048,7 +1052,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -1258,9 +1263,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
]
def __init__(self, config: MBartConfig):
diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py
index b31ac1bd635d95..fa19d711a311fe 100644
--- a/src/transformers/models/mbart/modeling_tf_mbart.py
+++ b/src/transformers/models/mbart/modeling_tf_mbart.py
@@ -229,7 +229,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -239,7 +242,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -255,7 +261,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -272,7 +281,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -763,7 +775,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
# encoder layers
@@ -969,7 +984,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
- message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1300,7 +1318,7 @@ def call(
if labels is not None:
labels = tf.where(
labels == self.config.pad_token_id,
- tf.fill(shape_list(labels), -100),
+ tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
labels,
)
use_cache = False
diff --git a/src/transformers/models/mbart/tokenization_mbart.py b/src/transformers/models/mbart/tokenization_mbart.py
index d6ea6260aec11e..2517dfb584bb23 100644
--- a/src/transformers/models/mbart/tokenization_mbart.py
+++ b/src/transformers/models/mbart/tokenization_mbart.py
@@ -32,8 +32,12 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/mbart-large-en-ro": "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model",
- "facebook/mbart-large-cc25": "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model",
+ "facebook/mbart-large-en-ro": (
+ "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model"
+ ),
+ "facebook/mbart-large-cc25": (
+ "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model"
+ ),
}
}
diff --git a/src/transformers/models/mbart/tokenization_mbart_fast.py b/src/transformers/models/mbart/tokenization_mbart_fast.py
index a172d37913a4be..52902e3a40f082 100644
--- a/src/transformers/models/mbart/tokenization_mbart_fast.py
+++ b/src/transformers/models/mbart/tokenization_mbart_fast.py
@@ -38,8 +38,12 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/mbart-large-en-ro": "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model",
- "facebook/mbart-large-cc25": "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model",
+ "facebook/mbart-large-en-ro": (
+ "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/sentencepiece.bpe.model"
+ ),
+ "facebook/mbart-large-cc25": (
+ "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/sentencepiece.bpe.model"
+ ),
},
"tokenizer_file": {
"facebook/mbart-large-en-ro": "https://huggingface.co/facebook/mbart-large-en-ro/resolve/main/tokenizer.json",
diff --git a/src/transformers/models/mbart50/__init__.py b/src/transformers/models/mbart50/__init__.py
index ee0edc94dfb448..299c0d0da7bbc9 100644
--- a/src/transformers/models/mbart50/__init__.py
+++ b/src/transformers/models/mbart50/__init__.py
@@ -17,23 +17,43 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_tokenizers_available
_import_structure = {}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
if TYPE_CHECKING:
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mbart50 import MBart50Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mbart50_fast import MBart50TokenizerFast
else:
diff --git a/src/transformers/models/mbart50/tokenization_mbart50.py b/src/transformers/models/mbart50/tokenization_mbart50.py
index c7e53c61495b07..145a546c181044 100644
--- a/src/transformers/models/mbart50/tokenization_mbart50.py
+++ b/src/transformers/models/mbart50/tokenization_mbart50.py
@@ -32,7 +32,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/mbart-large-50-one-to-many-mmt": "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model",
+ "facebook/mbart-large-50-one-to-many-mmt": (
+ "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
+ ),
}
}
diff --git a/src/transformers/models/mbart50/tokenization_mbart50_fast.py b/src/transformers/models/mbart50/tokenization_mbart50_fast.py
index 97e2584a0d003a..28fb726c476d8b 100644
--- a/src/transformers/models/mbart50/tokenization_mbart50_fast.py
+++ b/src/transformers/models/mbart50/tokenization_mbart50_fast.py
@@ -37,10 +37,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/mbart-large-50-one-to-many-mmt": "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model",
+ "facebook/mbart-large-50-one-to-many-mmt": (
+ "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
+ ),
},
"tokenizer_file": {
- "facebook/mbart-large-50-one-to-many-mmt": "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/tokenizer.json",
+ "facebook/mbart-large-50-one-to-many-mmt": (
+ "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/mctct/__init__.py b/src/transformers/models/mctct/__init__.py
new file mode 100644
index 00000000000000..6c28eb2214c56e
--- /dev/null
+++ b/src/transformers/models/mctct/__init__.py
@@ -0,0 +1,75 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig"],
+ "processing_mctct": ["MCTCTProcessor"],
+}
+
+
+try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_mctct"] = ["MCTCTFeatureExtractor"]
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_mctct"] = [
+ "MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MCTCTForCTC",
+ "MCTCTModel",
+ "MCTCTPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig
+ from .processing_mctct import MCTCTProcessor
+
+ try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_mctct import MCTCTFeatureExtractor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_mctct import MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST, MCTCTForCTC, MCTCTModel, MCTCTPreTrainedModel
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/mctct/configuration_mctct.py b/src/transformers/models/mctct/configuration_mctct.py
new file mode 100644
index 00000000000000..f71467e65dae3f
--- /dev/null
+++ b/src/transformers/models/mctct/configuration_mctct.py
@@ -0,0 +1,185 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""M-CTC-T model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "speechbrain/m-ctc-t-large": "https://huggingface.co/speechbrain/m-ctc-t-large/resolve/main/config.json",
+ # See all M-CTC-T models at https://huggingface.co/models?filter=mctct
+}
+
+
+class MCTCTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MCTCTModel`]. It is used to instantiate an
+ M-CTC-T model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the M-CTC-T
+ [speechbrain/m-ctc-t-large](https://huggingface.co/speechbrain/m-ctc-t-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 8065):
+ Vocabulary size of the M-CTC-T model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`MCTCTModel`].
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 36):
+ Number of hidden layers in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 6144):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ attention_head_dim (`int`, *optional*, defaults to 384):
+ Dimensions of each attention head for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 920):
+ The maximum sequence length that this model might ever be used with (after log-mel spectrogram extraction).
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ layerdrop (`float`, *optional*, defaults to 0.3):
+ The probability of dropping an encoder layer during training. The default 0.3 value is used in the original
+ implementation.
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ The tokenizer index of the pad token.
+ bos_token_id (`int`, *optional*, defaults to 0):
+ The tokenizer index of the bos token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The tokenizer index of the eos token.
+ conv_glu_dim (`int`, *optional*, defaults to 1):
+ The dimension of the output of the `Conv1dSubsampler` layer in which GLU is applied on. Though the original
+ Flashlight code uses the value of 2, here it's adapted to 1 due to transposition differences.
+ conv_dropout (`int`, *optional*, defaults to 0.3):
+ The probability of randomly dropping the `Conv1dSubsampler` layer during training.
+ num_conv_layers (`int`, *optional*, defaults to 1):
+ Number of convolution layers before applying transformer encoder layers.
+ conv_kernel (`List[int]`, *optional*, defaults to `[7]`):
+ The kernel size of the 1D convolution applied before transformer layers. `len(conv_kernel)` must be equal
+ to `num_conv_layers`.
+ conv_stride (`List[int]`, *optional*, defaults to `[3]`):
+ The stride length of the 1D convolution applied before transformer layers. `len(conv_stride)` must be equal
+ to `num_conv_layers`.
+ input_feat_per_channel (`int`, *optional*, defaults to 80):
+ Feature dimensions of the channels of the input to the Conv1D layer.
+ input_channels (`int`, *optional*, defaults to 1):
+ Number of input channels of the input to the Conv1D layer.
+ conv_channels (`List[int]`, *optional*, defaults to None):
+ Channel sizes of intermediate Conv1D layers.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`MCTCTForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`MCTCTForCTC`].
+
+ Example:
+
+ ```python
+ >>> from transformers import MCTCTModel, MCTCTConfig
+
+ >>> # Initializing a M-CTC-T mctct-large style configuration
+ >>> configuration = MCTCTConfig()
+
+ >>> # Initializing a model from the mctct-large style configuration
+ >>> model = MCTCTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "mctct"
+
+ def __init__(
+ self,
+ vocab_size=8065,
+ hidden_size=1536,
+ num_hidden_layers=36,
+ intermediate_size=6144,
+ num_attention_heads=4,
+ attention_head_dim=384,
+ max_position_embeddings=920,
+ layer_norm_eps=1e-5,
+ layerdrop=0.3,
+ hidden_act="relu",
+ initializer_range=0.02,
+ hidden_dropout_prob=0.3,
+ attention_probs_dropout_prob=0.3,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ conv_glu_dim=1,
+ conv_dropout=0.3,
+ num_conv_layers=1,
+ conv_kernel=(7,),
+ conv_stride=(3,),
+ input_feat_per_channel=80,
+ input_channels=1,
+ conv_channels=None,
+ ctc_loss_reduction="sum",
+ ctc_zero_infinity=False,
+ **kwargs
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.max_position_embeddings = max_position_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.layerdrop = layerdrop
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.conv_glu_dim = conv_glu_dim
+ self.conv_dropout = conv_dropout
+ self.num_conv_layers = num_conv_layers
+ self.input_feat_per_channel = input_feat_per_channel
+ self.input_channels = input_channels
+ self.conv_channels = conv_channels
+ self.ctc_loss_reduction = ctc_loss_reduction
+ self.ctc_zero_infinity = ctc_zero_infinity
+
+ # prevents config testing fail with exporting to json
+ self.conv_kernel = list(conv_kernel)
+ self.conv_stride = list(conv_stride)
+
+ if len(self.conv_kernel) != self.num_conv_layers:
+ raise ValueError(
+ "Configuration for convolutional module is incorrect. "
+ "It is required that `len(config.conv_kernel)` == `config.num_conv_layers` "
+ f"but is `len(config.conv_kernel) = {len(self.conv_kernel)}`, "
+ f"`config.num_conv_layers = {self.num_conv_layers}`."
+ )
diff --git a/src/transformers/models/mctct/feature_extraction_mctct.py b/src/transformers/models/mctct/feature_extraction_mctct.py
new file mode 100644
index 00000000000000..573551bcf7780d
--- /dev/null
+++ b/src/transformers/models/mctct/feature_extraction_mctct.py
@@ -0,0 +1,356 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for M-CTC-T
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import torchaudio
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...file_utils import PaddingStrategy, TensorType
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MCTCTFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs a M-CTC-T feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods. This
+ code has been adapted from Flashlight's C++ code. For more information about the implementation, one can refer to
+ this [notebook](https://colab.research.google.com/drive/1GLtINkkhzms-IsdcGy_-tVCkv0qNF-Gt#scrollTo=pMCRGMmUC_an)
+ that takes the user step-by-step in the implementation.
+
+ Args:
+ feature_size (`int`, defaults to 80):
+ The feature dimension of the extracted features. This is the number of mel_frequency
+ sampling_rate (`int`, defaults to 16000):
+ The sampling rate at which the audio files should be digitalized expressed in Hertz per second (Hz).
+ padding_value (`float`, defaults to 0.0):
+ The value that is used to fill the padding values.
+ hop_length (`int`, defaults to 10):
+ Number of audio samples between windows. Otherwise referred to as "shift" in many papers.
+ win_length (`int`, defaults to 25):
+ Number of ms per window
+ win_function (`str`, defaults to `"hamming_window"`):
+ Name for the window function used for windowing, must be accessible via `torch.{win_function}`
+ frame_signal_scale (`float`, defaults to 32768.0):
+ Constant multiplied in creating the frames before applying DFT.
+ preemphasis_coeff (`float`, defaults to 0.97):
+ Constant multiplied in applying Pre-emphasis before DFT.
+ mel_floor (`float` defaults to 1.0):
+ Minimum value of mel frequency banks.
+ normalize_means (`bool`, *optional*, defaults to `True`):
+ Whether or not to zero-mean normalize the extracted features.
+ normalize_vars (`bool`, *optional*, defaults to `True`):
+ Whether or not to unit-variance normalize the extracted features.
+ """
+
+ model_input_names = ["input_features", "attention_mask"]
+
+ def __init__(
+ self,
+ feature_size=80,
+ sampling_rate=16000,
+ padding_value=0.0,
+ hop_length=10,
+ win_length=25,
+ win_function="hamming_window",
+ frame_signal_scale=32768.0,
+ preemphasis_coeff=0.97,
+ mel_floor=1.0,
+ normalize_means=True,
+ normalize_vars=True,
+ return_attention_mask=False,
+ **kwargs
+ ):
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+
+ self.feature_size = feature_size
+ self.sampling_rate = sampling_rate
+ self.padding_value = padding_value
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.frame_signal_scale = frame_signal_scale
+ self.preemphasis_coeff = preemphasis_coeff
+ self.mel_floor = mel_floor
+ self.normalize_means = normalize_means
+ self.normalize_vars = normalize_vars
+ self.win_function = win_function
+ self.return_attention_mask = return_attention_mask
+
+ self.sample_size = win_length * sampling_rate // 1000
+ self.sample_stride = hop_length * sampling_rate // 1000
+
+ self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
+ self.n_freqs = (self.n_fft // 2) + 1
+
+ @staticmethod
+ def _num_frames_calc(in_size, frame_size, frame_stride):
+ return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
+
+ @staticmethod
+ def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
+ scale = frame_signal_scale
+ frames = np.zeros(n_frames * window_length)
+ for frame_idx in range(n_frames):
+ start = frame_idx * window_length
+ end = (frame_idx + 1) * window_length
+ wave_start = frame_idx * sample_stride
+ wave_end = frame_idx * sample_stride + window_length
+ frames[start:end] = scale * one_waveform[wave_start:wave_end]
+
+ return frames
+
+ @staticmethod
+ def _apply_preemphasis_inplace(frames, window_length, preemphasis_coeff):
+ if frames.size % window_length != 0:
+ raise ValueError(
+ f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
+ f" window_length={window_length}."
+ )
+
+ n_frames = frames.size // window_length
+ for frame_idx in range(n_frames, 0, -1):
+ start = (frame_idx - 1) * window_length
+ end = frame_idx * window_length - 1
+ frames[start + 1 : end + 1] -= preemphasis_coeff * frames[start:end]
+ frames[start] *= 1 - preemphasis_coeff
+
+ @staticmethod
+ def _windowing(frames, window_length, window):
+ if frames.size % window_length != 0:
+ raise ValueError(
+ f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
+ f" window_length={window_length}."
+ )
+
+ shaped = frames.reshape(-1, window_length)
+ shaped = window * shaped
+ return shaped
+
+ @staticmethod
+ def _dft(frames, K, n_frames, n_samples, n_fft):
+ dft = np.zeros([n_frames, K])
+
+ for frame in range(n_frames):
+ begin = frame * n_samples
+
+ inwards_buffer = frames[begin : begin + n_samples]
+ inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
+ out = np.fft.rfft(inwards_buffer)
+
+ dft[frame] = np.abs(out[:K])
+
+ return dft
+
+ def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray:
+ """
+ Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
+ """
+ if self.win_function == "hamming_window":
+ window = torch.hamming_window(window_length=self.sample_size, periodic=False, alpha=0.54, beta=0.46)
+ else:
+ window = getattr(torch, self.win_function)()
+
+ window = window.numpy()
+
+ fbanks = torchaudio.functional.melscale_fbanks(
+ n_freqs=self.n_freqs,
+ f_min=0.0, # change this to zeros
+ f_max=self.sampling_rate / 2.0,
+ n_mels=self.feature_size,
+ sample_rate=self.sampling_rate,
+ )
+
+ fbanks = fbanks.numpy()
+
+ n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
+
+ frames = self._frame_signal(
+ one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
+ )
+
+ self._apply_preemphasis_inplace(frames, self.sample_size, self.preemphasis_coeff)
+
+ frames = self._windowing(frames, self.sample_size, window)
+
+ dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
+
+ # msfc_features = STFT * mel frequency banks.
+ msfc_features = np.einsum("...tf,fm->...tm", dft_out, fbanks)
+
+ # clamp feature values then log scale, as implemented in flashlight
+ msfc_features = np.maximum(msfc_features, self.mel_floor)
+ msfc_features = np.log(msfc_features)
+
+ return msfc_features
+
+ def _normalize_one(self, x, input_length, padding_value):
+ # make sure we normalize float32 arrays
+ if self.normalize_means:
+ mean = x[:input_length].mean(axis=0)
+ x = np.subtract(x, mean)
+ if self.normalize_vars:
+ std = x[:input_length].std(axis=0)
+ x = np.divide(x, std)
+
+ if input_length < x.shape[0]:
+ x[input_length:] = padding_value
+
+ # make sure array is in float32
+ x = x.astype(np.float32)
+
+ return x
+
+ def normalize(
+ self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
+ ) -> List[np.ndarray]:
+ lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
+ return [self._normalize_one(x, n, self.padding_value) for x, n in zip(input_features, lengths)]
+
+ def __call__(
+ self,
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+ padding: Union[bool, str, PaddingStrategy] = False,
+ max_length: Optional[int] = None,
+ truncation: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ sampling_rate: Optional[int] = None,
+ **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to featurize and prepare for the model one or several sequence(s). sequences. It returns the
+ log-mel spectrogram of the input audio, as implemented in the original Flashlight MFSC feature extraction code.
+
+ Args:
+ raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`):
+ The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list
+ of float values, a list of tensors, a list of numpy arrays or a list of list of float values.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`):
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value.
+
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ >= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific feature_extractor's default.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors.
+ padding_value (`float`, defaults to 0.0):
+ """
+
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ "It is strongly recommended to pass the ``sampling_rate`` argument to this function. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ is_batched = bool(
+ isinstance(raw_speech, (list, tuple))
+ and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
+ )
+
+ if is_batched:
+ raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+ raw_speech = raw_speech.astype(np.float32)
+
+ # always return batch
+ if not is_batched:
+ raw_speech = [raw_speech]
+
+ # extract fbank features
+ features = [self._extract_mfsc_features(one_waveform) for one_waveform in raw_speech]
+
+ # convert into correct format for padding
+ encoded_inputs = BatchFeature({"input_features": features})
+
+ padded_inputs = self.pad(
+ encoded_inputs,
+ padding=padding,
+ max_length=max_length,
+ truncation=truncation,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=True,
+ **kwargs,
+ )
+ # make sure list is in array format
+ input_features = padded_inputs.get("input_features")
+ if isinstance(input_features[0], list):
+ padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
+
+ attention_mask = padded_inputs.get("attention_mask")
+ if attention_mask is not None:
+ padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
+
+ if self.normalize_means or self.normalize_vars:
+ attention_mask = (
+ np.array(attention_mask, dtype=np.int32)
+ if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
+ and padding
+ else None
+ )
+ padded_inputs["input_features"] = self.normalize(
+ padded_inputs["input_features"], attention_mask=attention_mask
+ )
+
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+ return padded_inputs
diff --git a/src/transformers/models/mctct/modeling_mctct.py b/src/transformers/models/mctct/modeling_mctct.py
new file mode 100755
index 00000000000000..25d368b7dc75d6
--- /dev/null
+++ b/src/transformers/models/mctct/modeling_mctct.py
@@ -0,0 +1,825 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch M-CTC-T model."""
+
+
+import math
+import random
+from typing import Optional
+
+import torch
+import torch.utils.checkpoint
+from packaging import version
+from torch import nn
+
+from ...activations import ACT2FN
+from ...deepspeed import is_deepspeed_zero3_enabled
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from ...utils import logging
+from .configuration_mctct import MCTCTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_HIDDEN_STATES_START_POSITION = 1
+
+_CONFIG_FOR_DOC = "MCTCTConfig"
+_PROCESSOR_FOR_DOC = "MCTCTProcessor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "speechbrain/m-ctc-t-large"
+_EXPECTED_OUTPUT_SHAPE = [1, 195, 1536]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = '"Mr. Quilter is the apostle of the middle classes, and we\'re glad to welcome his gospel."'
+_CTC_EXPECTED_LOSS = 1885.65
+
+
+MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "speechbrain/m-ctc-t-large",
+ # See all M-CTC-T models at https://huggingface.co/models?filter=mctct
+]
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class MCTCTConv1dSubsampler(nn.Module):
+ """
+ Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
+ via gated linear units (https://arxiv.org/abs/1911.08460)
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.glu_dim = config.conv_glu_dim
+
+ self.dropout = nn.Dropout(config.conv_dropout)
+
+ self.num_layers = config.num_conv_layers
+ self.in_channels = config.input_feat_per_channel * config.input_channels
+
+ if self.num_layers > 1:
+ if config.conv_channels is None:
+ raise ValueError(
+ "Need to specify `conv_channels` configuration in `MCTCTConfig` to use multiple convolution"
+ " layers."
+ )
+
+ self.mid_channels = config.conv_channels
+ else:
+ self.mid_channels = None
+
+ self.out_channels = config.hidden_size * 2 # considering GLU halving
+ self.kernel_size = config.conv_kernel
+ self.stride = config.conv_stride
+
+ # NOTE: MCTCT by construction only uses one convolution kernel. I've made this flexible to allow for
+ # multiple layers of convolutions, but not sure if this model definition should just restrict it
+ # to one layer. This becomes especially relevant when considering the padding like line 1 of forward().
+ self.conv_layers = nn.ModuleList(
+ nn.Conv1d(
+ self.in_channels if i == 0 else self.mid_channels[i],
+ self.mid_channels[i] if i < self.num_layers - 1 else self.out_channels,
+ kernel_size=k,
+ stride=self.stride[i],
+ padding="valid",
+ )
+ for i, k in enumerate(self.kernel_size)
+ )
+
+ def forward(self, input_features):
+ # NOTE: in reference to the NOTE in __init__, right now it just calculates padding as if
+ # there will be just one conv layer.
+ padding = sum([size // 2 for size in self.kernel_size]) # (7, 7) -> (3, 3)
+
+ input_features = torch.nn.functional.pad(input_features, (0, 0, padding, padding), "constant", 0)
+ hidden_states = input_features.transpose(1, 2).contiguous() # -> Batch x Frame x Time
+ for conv in self.conv_layers:
+ hidden_states = conv(hidden_states)
+ hidden_states = nn.functional.glu(hidden_states, dim=self.glu_dim)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2).contiguous() # -> Batch x Time x Frame
+ return hidden_states
+
+
+class MCTCTEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.LayerNorm = MCTCTLayerNorm()
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+ persistent=False,
+ )
+
+ def forward(
+ self, input_features=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ input_shape = input_features.size() if input_features is not None else inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+ # issue #5664
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_features)
+
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class MCTCTSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = config.attention_head_dim
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def reshape_fortran(self, x, shape):
+ if len(x.shape) > 0:
+ x = x.permute(*reversed(range(len(x.shape))))
+ return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
+
+ def relative_position_embedding_rotate(self, scores):
+ # NOTE: should re-evaluate whether this re-implementation was truly necessary
+ # or the reason why my complete re-haul worked was due to some other part
+ # of the code. Adding this and the reshape fortrain code seems very undesirable.
+ scores = scores.permute(0, 2, 3, 1) # e.g. [10, 1839, 14, 4]
+
+ batch, hidden_state, seq_len, heads = scores.shape
+
+ # e.g. [10, 1853, 14, 4]
+ scores = torch.cat((scores, torch.zeros((batch, seq_len, seq_len, heads), device=scores.device)), dim=1)
+
+ # e.g. [10, 25942, 1, 4]
+ scores = self.reshape_fortran(scores, [batch, (hidden_state + seq_len) * seq_len, 1, heads])
+
+ # e.g. [10, 25928, 1, 4]
+ scores = scores[:, : (seq_len + hidden_state - 1) * seq_len]
+
+ # e.g. [10, 1852, 14, 4]
+ scores = self.reshape_fortran(scores, [batch, hidden_state + seq_len - 1, seq_len, heads])
+
+ halfpoint = hidden_state // 2
+ scores = scores[:, halfpoint : halfpoint + seq_len].transpose(1, 2) # e.g. [10, 14, 14, 4]
+
+ return scores.permute(0, 3, 1, 2)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+ mixed_query_layer = mixed_query_layer / math.sqrt(self.attention_head_size)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ # relative key position embeddings
+ positional_embedding = self.distance_embedding.weight
+ relative_position_scores = torch.einsum("lh, bche -> bcle", positional_embedding, query_layer.transpose(2, 3))
+
+ relative_position_scores = self.relative_position_embedding_rotate(relative_position_scores)
+ attention_scores = attention_scores + relative_position_scores
+
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in MCTCTModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).flatten(start_dim=-2)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class MCTCTLayerNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.singleton_weight = nn.Parameter(torch.ones(1))
+ self.singleton_bias = nn.Parameter(torch.zeros(1))
+
+ def forward(self, hidden_states):
+ return (hidden_states * self.singleton_weight) + self.singleton_bias
+
+
+class MCTCTSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class MCTCTAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = MCTCTSelfAttention(config)
+ self.output = MCTCTSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+
+ return outputs
+
+
+class MCTCTIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class MCTCTOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class MCTCTLayer(nn.Module):
+ def __init__(self, config: MCTCTConfig):
+ super().__init__()
+
+ self.seq_len_dim = 1
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+
+ self.intermediate = MCTCTIntermediate(config)
+ self.attention = MCTCTAttention(config)
+ self.is_decoder = config.is_decoder
+ self.output = MCTCTOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ self_attention_outputs = self.attention(
+ hidden_states, attention_mask, head_mask, output_attentions=output_attentions
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class MCTCTPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = MCTCTConfig
+ base_model_prefix = "mctct"
+ main_input_name = "input_features"
+ _keys_to_ignore_on_load_missing = ["position_ids"]
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, MCTCTLayerNorm):
+ module.singleton_weight.data.fill_(1.0)
+ module.singleton_bias.data.zero_()
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers
+ """
+ dilation = 1
+ for _, kernel_sz, stride in zip(
+ range(self.config.num_conv_layers), self.config.conv_kernel, self.config.conv_stride
+ ):
+ padding = kernel_sz // 2
+ input_lengths = input_lengths + 2 * padding - dilation * (kernel_sz - 1) - 1
+ input_lengths = torch.div(input_lengths, stride, rounding_mode="trunc") + 1
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask):
+ # generate creates 3D attention mask, because of the shape of input_features
+ # convert it to 2D if thats the case
+ if len(attention_mask.shape) > 2:
+ attention_mask = attention_mask[:, :, -1]
+
+ # subsampled_lengths = attention_mask.sum(-1)
+ subsampled_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+ bsz = attention_mask.size()[0]
+ attention_mask = torch.zeros(
+ (bsz, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+
+ # these two operations makes sure that all values
+ # before the output lengths indices are attended to
+ attention_mask[(torch.arange(bsz, device=attention_mask.device), subsampled_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long()
+ return attention_mask
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (MCTCTEncoder)):
+ module.gradient_checkpointing = value
+
+
+MCTCT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`MCTCTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MCTCT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_features (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`Wav2Vec2CTCTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class MCTCTEncoder(MCTCTPreTrainedModel):
+ def __init__(self, config: MCTCTConfig):
+ super().__init__(config)
+ self.hidden_dropout_prob = config.hidden_dropout_prob
+
+ self.layer_norm = MCTCTLayerNorm()
+ self.conv = MCTCTConv1dSubsampler(config)
+ self.layers = nn.ModuleList([MCTCTLayer(config) for _ in range(config.num_hidden_layers)])
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ input_features,
+ attention_mask,
+ head_mask,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ input_features = self.layer_norm(input_features)
+
+ inputs_embeds = self.conv(input_features)
+
+ # subsample attention mask if necessary
+ if attention_mask is not None:
+ attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)
+
+ hidden_states = nn.functional.dropout(inputs_embeds, p=self.hidden_dropout_prob, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, "
+ f"but it is for {head_mask.size()[0]}."
+ )
+
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
+ # under deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+@add_start_docstrings(
+ "The bare M-CTC-T Model transformer outputting raw hidden-states without any specific head on top.",
+ MCTCT_START_DOCSTRING,
+)
+class MCTCTModel(MCTCTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.encoder = MCTCTEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ processor_class=_PROCESSOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_features,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_features is None:
+ raise ValueError("You have to specify input_features.")
+
+ encoder_outputs = self.encoder(
+ input_features,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """MCTCT Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ MCTCT_START_DOCSTRING,
+)
+class MCTCTForCTC(MCTCTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.mctct = MCTCTModel(config)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `MCTCTForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ output_hidden_size = config.hidden_size
+
+ self.ctc_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_PROCESSOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_features,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ labels=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ outputs = self.mctct(
+ input_features,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ logits = self.ctc_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+
+ if labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask
+ if attention_mask is not None
+ else torch.ones(input_features.shape[:-1], dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
diff --git a/src/transformers/models/mctct/processing_mctct.py b/src/transformers/models/mctct/processing_mctct.py
new file mode 100644
index 00000000000000..0892f345928b11
--- /dev/null
+++ b/src/transformers/models/mctct/processing_mctct.py
@@ -0,0 +1,82 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Speech processor class for M-CTC-T
+"""
+from contextlib import contextmanager
+
+from ...processing_utils import ProcessorMixin
+
+
+class MCTCTProcessor(ProcessorMixin):
+ r"""
+ Constructs a MCTCT processor which wraps a MCTCT feature extractor and a MCTCT tokenizer into a single processor.
+
+ [`MCTCTProcessor`] offers all the functionalities of [`MCTCTFeatureExtractor`] and [`AutoTokenizer`]. See the
+ [`~MCTCTProcessor.__call__`] and [`~MCTCTProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor (`MCTCTFeatureExtractor`):
+ An instance of [`MCTCTFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`AutoTokenizer`):
+ An instance of [`AutoTokenizer`]. The tokenizer is a required input.
+ """
+ feature_extractor_class = "MCTCTFeatureExtractor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+
+ def __call__(self, *args, **kwargs):
+ """
+ When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
+ [`~MCTCTFeatureExtractor.__call__`] and returns its output. If used in the context
+ [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's
+ [`~AutoTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information.
+ """
+ return self.current_processor(*args, **kwargs)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def pad(self, *args, **kwargs):
+ """
+ When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
+ [`~MCTCTFeatureExtractor.pad`] and returns its output. If used in the context
+ [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
+ [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
+ """
+ return self.current_processor.pad(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
+ docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @contextmanager
+ def as_target_processor(self):
+ """
+ Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT.
+ """
+ self.current_processor = self.tokenizer
+ yield
+ self.current_processor = self.feature_extractor
diff --git a/src/transformers/models/megatron_bert/__init__.py b/src/transformers/models/megatron_bert/__init__.py
index d49ab274e565e6..9075b898377a3e 100644
--- a/src/transformers/models/megatron_bert/__init__.py
+++ b/src/transformers/models/megatron_bert/__init__.py
@@ -17,14 +17,19 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_megatron_bert"] = [
"MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"MegatronBertForCausalLM",
@@ -42,7 +47,12 @@
if TYPE_CHECKING:
from .configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MegatronBertForCausalLM,
diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
index b64a0d41b939ec..371782c2976e9a 100755
--- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py
+++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
@@ -20,7 +20,7 @@
import os
import warnings
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -154,8 +154,13 @@ def __init__(self, config):
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def forward(
- self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
- ):
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ past_key_values_length: int = 0,
+ ) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -212,7 +217,7 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@@ -319,7 +324,7 @@ def __init__(self, config):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states, residual):
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return residual + hidden_states
@@ -354,14 +359,14 @@ def prune_heads(self, heads):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_value=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
ln_outputs = self.ln(hidden_states)
self_outputs = self.self(
ln_outputs,
@@ -400,7 +405,7 @@ def __init__(self, config):
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states, input_tensor):
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return input_tensor + hidden_states
@@ -425,14 +430,14 @@ def __init__(self, config):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_value=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
@@ -455,7 +460,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise AttributeError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -507,17 +513,17 @@ def __init__(self, config):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -873,20 +879,20 @@ class PreTrainedModel
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -940,7 +946,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1022,18 +1028,18 @@ def set_output_embeddings(self, new_embeddings):
@replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- next_sentence_label=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ next_sentence_label: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MegatronBertForPreTrainingOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1133,21 +1139,21 @@ def set_output_embeddings(self, new_embeddings):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- labels=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1287,19 +1293,19 @@ def set_output_embeddings(self, new_embeddings):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1379,18 +1385,18 @@ def __init__(self, config):
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
**kwargs
- ):
+ ) -> Union[Tuple, NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
@@ -1421,7 +1427,8 @@ def forward(
if "next_sentence_label" in kwargs:
warnings.warn(
- "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
@@ -1489,17 +1496,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1588,17 +1595,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1684,17 +1691,17 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1765,18 +1772,18 @@ def __init__(self, config):
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- start_positions=None,
- end_positions=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/src/transformers/models/mluke/__init__.py b/src/transformers/models/mluke/__init__.py
index acd6dff11f1955..b6582e35a9d0b1 100644
--- a/src/transformers/models/mluke/__init__.py
+++ b/src/transformers/models/mluke/__init__.py
@@ -18,17 +18,27 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available
_import_structure = {}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mluke"] = ["MLukeTokenizer"]
if TYPE_CHECKING:
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mluke import MLukeTokenizer
diff --git a/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
index c75a710cee2fa4..9d61c3bc8e272a 100644
--- a/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
@@ -153,7 +153,8 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
if not (outputs.entity_last_hidden_state.shape == expected_shape):
raise ValueError(
- f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}"
+ f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is"
+ f" {expected_shape}"
)
if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError
diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py
index 1ddf472d56cc20..57272c391fb30c 100644
--- a/src/transformers/models/mluke/tokenization_mluke.py
+++ b/src/transformers/models/mluke/tokenization_mluke.py
@@ -18,6 +18,7 @@
import itertools
import json
import os
+from collections.abc import Mapping
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -341,7 +342,8 @@ def __init__(
self.max_entity_length = 2
else:
raise ValueError(
- f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification', 'entity_span_classification'] only."
+ f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification',"
+ " 'entity_span_classification'] only."
)
self.max_mention_length = max_mention_length
@@ -706,7 +708,7 @@ def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spa
raise ValueError("entity_spans should be given as a list")
elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple):
raise ValueError(
- "entity_spans should be given as a list of tuples " "containing the start and end character indices"
+ "entity_spans should be given as a list of tuples containing the start and end character indices"
)
if entities is not None:
@@ -1118,7 +1120,8 @@ def prepare_for_model(
if num_invalid_entities != 0:
logger.warning(
- f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the truncation of input tokens"
+ f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the"
+ " truncation of input tokens"
)
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length:
@@ -1143,7 +1146,7 @@ def prepare_for_model(
entity_position_ids = []
entity_start_positions = []
entity_end_positions = []
- for (token_spans, offset) in (
+ for token_spans, offset in (
(valid_entity_token_spans, entity_token_offset),
(valid_pair_entity_token_spans, pair_entity_token_offset),
):
@@ -1253,7 +1256,7 @@ def pad(
"""
# If we have a list of dicts, let's convert it in a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
- if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
# The model's main input name, usually `input_ids`, has be passed for padding
@@ -1293,7 +1296,7 @@ def pad(
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
- f"Should be one of a python, numpy, pytorch or tensorflow object."
+ "Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in encoded_inputs.items():
@@ -1506,7 +1509,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.entity_vocab, ensure_ascii=False))
+ f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return out_vocab_file, entity_vocab_file
diff --git a/src/transformers/models/mmbt/__init__.py b/src/transformers/models/mmbt/__init__.py
index 763a256f1a20b6..d95a2cc8d84ae5 100644
--- a/src/transformers/models/mmbt/__init__.py
+++ b/src/transformers/models/mmbt/__init__.py
@@ -18,21 +18,29 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_mmbt": ["MMBTConfig"],
-}
+_import_structure = {"configuration_mmbt": ["MMBTConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mmbt"] = ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]
if TYPE_CHECKING:
from .configuration_mmbt import MMBTConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
else:
diff --git a/src/transformers/models/mmbt/modeling_mmbt.py b/src/transformers/models/mmbt/modeling_mmbt.py
index 5e284c1b699657..8819dc4d5178c0 100644
--- a/src/transformers/models/mmbt/modeling_mmbt.py
+++ b/src/transformers/models/mmbt/modeling_mmbt.py
@@ -268,7 +268,7 @@ def forward(
[torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
)
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
diff --git a/src/transformers/models/mobilebert/__init__.py b/src/transformers/models/mobilebert/__init__.py
index 505dabe1879198..ae91c38bdfb356 100644
--- a/src/transformers/models/mobilebert/__init__.py
+++ b/src/transformers/models/mobilebert/__init__.py
@@ -18,18 +18,38 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
- "configuration_mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig"],
+ "configuration_mobilebert": [
+ "MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "MobileBertConfig",
+ "MobileBertOnnxConfig",
+ ],
"tokenization_mobilebert": ["MobileBertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mobilebert_fast"] = ["MobileBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mobilebert"] = [
"MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"MobileBertForMaskedLM",
@@ -45,7 +65,12 @@
"load_tf_weights_in_mobilebert",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_mobilebert"] = [
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFMobileBertForMaskedLM",
@@ -62,13 +87,27 @@
if TYPE_CHECKING:
- from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
+ from .configuration_mobilebert import (
+ MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ MobileBertConfig,
+ MobileBertOnnxConfig,
+ )
from .tokenization_mobilebert import MobileBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mobilebert_fast import MobileBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mobilebert import (
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MobileBertForMaskedLM,
@@ -84,7 +123,12 @@
load_tf_weights_in_mobilebert,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
diff --git a/src/transformers/models/mobilebert/configuration_mobilebert.py b/src/transformers/models/mobilebert/configuration_mobilebert.py
index 27863235b3d7d8..73b8844ed763df 100644
--- a/src/transformers/models/mobilebert/configuration_mobilebert.py
+++ b/src/transformers/models/mobilebert/configuration_mobilebert.py
@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" MobileBERT model configuration"""
+from collections import OrderedDict
+from typing import Mapping
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
@@ -165,3 +168,20 @@ def __init__(
self.true_hidden_size = hidden_size
self.classifier_dropout = classifier_dropout
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->MobileBert
+class MobileBertOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("token_type_ids", dynamic_axis),
+ ]
+ )
diff --git a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
index 5c03331eb3d9af..022a9d036cdb24 100644
--- a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
@@ -46,8 +46,10 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file,
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained MobileBERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained MobileBERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py
index 1a2156ed31d034..6bc306a6e05eb3 100644
--- a/src/transformers/models/mobilebert/modeling_mobilebert.py
+++ b/src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -164,7 +164,7 @@ def __init__(self, feat_size, eps=None):
self.bias = nn.Parameter(torch.zeros(feat_size))
self.weight = nn.Parameter(torch.ones(feat_size))
- def forward(self, input_tensor):
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return input_tensor * self.weight + self.bias
@@ -194,7 +194,13 @@ def __init__(self, config):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -220,9 +226,9 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs
# dimensional output.
inputs_embeds = torch.cat(
[
- nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0),
+ nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
inputs_embeds,
- nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0),
+ nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
],
dim=2,
)
@@ -260,13 +266,13 @@ def transpose_for_scores(self, x):
def forward(
self,
- query_tensor,
- key_tensor,
- value_tensor,
- attention_mask=None,
- head_mask=None,
- output_attentions=None,
- ):
+ query_tensor: torch.Tensor,
+ key_tensor: torch.Tensor,
+ value_tensor: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ ) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(query_tensor)
mixed_key_layer = self.key(key_tensor)
mixed_value_layer = self.value(value_tensor)
@@ -306,7 +312,7 @@ def __init__(self, config):
if not self.use_bottleneck:
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states, residual_tensor):
+ def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states)
if not self.use_bottleneck:
layer_outputs = self.dropout(layer_outputs)
@@ -341,14 +347,14 @@ def prune_heads(self, heads):
def forward(
self,
- query_tensor,
- key_tensor,
- value_tensor,
- layer_input,
- attention_mask=None,
- head_mask=None,
- output_attentions=None,
- ):
+ query_tensor: torch.Tensor,
+ key_tensor: torch.Tensor,
+ value_tensor: torch.Tensor,
+ layer_input: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ ) -> Tuple[torch.Tensor]:
self_outputs = self.self(
query_tensor,
key_tensor,
@@ -373,7 +379,7 @@ def __init__(self, config):
else:
self.intermediate_act_fn = config.hidden_act
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
@@ -386,7 +392,7 @@ def __init__(self, config):
self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states, residual_tensor):
+ def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states)
layer_outputs = self.dropout(layer_outputs)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
@@ -404,7 +410,9 @@ def __init__(self, config):
else:
self.bottleneck = OutputBottleneck(config)
- def forward(self, intermediate_states, residual_tensor_1, residual_tensor_2):
+ def forward(
+ self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor
+ ) -> torch.Tensor:
layer_output = self.dense(intermediate_states)
if not self.use_bottleneck:
layer_output = self.dropout(layer_output)
@@ -421,7 +429,7 @@ def __init__(self, config):
self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
layer_input = self.dense(hidden_states)
layer_input = self.LayerNorm(layer_input)
return layer_input
@@ -436,7 +444,7 @@ def __init__(self, config):
if self.key_query_shared_bottleneck:
self.attention = BottleneckLayer(config)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
# This method can return three different tuples of values. These different values make use of bottlenecks,
# which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory
# usage. These linear layer have weights that are learned during training.
@@ -469,7 +477,7 @@ def __init__(self, config):
self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
- def forward(self, hidden_states, residual_tensor):
+ def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
return layer_outputs
@@ -481,7 +489,7 @@ def __init__(self, config):
self.intermediate = MobileBertIntermediate(config)
self.output = FFNOutput(config)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate(hidden_states)
layer_outputs = self.output(intermediate_output, hidden_states)
return layer_outputs
@@ -503,11 +511,11 @@ def __init__(self, config):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- output_attentions=None,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ ) -> Tuple[torch.Tensor]:
if self.use_bottleneck:
query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
else:
@@ -557,13 +565,13 @@ def __init__(self, config):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
@@ -599,7 +607,7 @@ def __init__(self, config):
if self.do_activate:
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
@@ -621,7 +629,7 @@ def __init__(self, config):
self.transform_act_fn = config.hidden_act
self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
@@ -640,7 +648,7 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.transform(hidden_states)
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
hidden_states += self.decoder.bias
@@ -652,7 +660,7 @@ def __init__(self, config):
super().__init__()
self.predictions = MobileBertLMPredictionHead(config)
- def forward(self, sequence_output):
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output)
return prediction_scores
@@ -663,7 +671,7 @@ def __init__(self, config):
self.predictions = MobileBertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
- def forward(self, sequence_output, pooled_output):
+ def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> Tuple[torch.Tensor]:
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
@@ -841,16 +849,16 @@ class PreTrainedModel
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- output_hidden_states=None,
- output_attentions=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -875,9 +883,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- attention_mask, input_shape, self.device
- )
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -945,18 +951,18 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Em
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- next_sentence_label=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ next_sentence_label: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[torch.FloatTensor] = None,
+ ) -> Union[Tuple, MobileBertForPreTrainingOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1061,17 +1067,17 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Em
)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1117,7 +1123,7 @@ def __init__(self, config):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
- def forward(self, pooled_output):
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
@@ -1140,18 +1146,18 @@ def __init__(self, config):
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- input_ids=None,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
**kwargs,
- ):
+ ) -> Union[Tuple, NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
@@ -1182,7 +1188,8 @@ def forward(
if "next_sentence_label" in kwargs:
warnings.warn(
- "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
diff --git a/src/transformers/models/mpnet/__init__.py b/src/transformers/models/mpnet/__init__.py
index 54c2c7b8419a63..5b3bc0dbd37559 100644
--- a/src/transformers/models/mpnet/__init__.py
+++ b/src/transformers/models/mpnet/__init__.py
@@ -18,7 +18,14 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +33,20 @@
"tokenization_mpnet": ["MPNetTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_mpnet_fast"] = ["MPNetTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mpnet"] = [
"MPNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"MPNetForMaskedLM",
@@ -42,7 +59,12 @@
"MPNetPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_mpnet"] = [
"TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFMPNetEmbeddings",
@@ -61,10 +83,20 @@
from .configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig
from .tokenization_mpnet import MPNetTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_mpnet_fast import MPNetTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mpnet import (
MPNET_PRETRAINED_MODEL_ARCHIVE_LIST,
MPNetForMaskedLM,
@@ -77,7 +109,12 @@
MPNetPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_mpnet import (
TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMPNetEmbeddings,
diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py
index 89b68544a1efdb..e7977561fe2b1a 100644
--- a/src/transformers/models/mpnet/modeling_mpnet.py
+++ b/src/transformers/models/mpnet/modeling_mpnet.py
@@ -547,7 +547,7 @@ def forward(
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
diff --git a/src/transformers/models/mpnet/tokenization_mpnet.py b/src/transformers/models/mpnet/tokenization_mpnet.py
index f092e6a311a964..713a528d557a04 100644
--- a/src/transformers/models/mpnet/tokenization_mpnet.py
+++ b/src/transformers/models/mpnet/tokenization_mpnet.py
@@ -175,8 +175,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py
index dd576cb0b25cd0..66daf82b388df5 100644
--- a/src/transformers/models/mt5/__init__.py
+++ b/src/transformers/models/mt5/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -42,30 +43,58 @@
MT5TokenizerFast = T5TokenizerFast
-_import_structure = {
- "configuration_mt5": ["MT5Config"],
-}
+_import_structure = {"configuration_mt5": ["MT5Config"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_mt5"] = ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
if TYPE_CHECKING:
from .configuration_mt5 import MT5Config
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
else:
diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py
index d6a343f77dbce3..ad0345f53189e9 100644
--- a/src/transformers/models/mt5/configuration_mt5.py
+++ b/src/transformers/models/mt5/configuration_mt5.py
@@ -117,6 +117,21 @@ def __init__(
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
+ act_info = self.feed_forward_proj.split("-")
+ self.dense_act_fn = act_info[-1]
+ self.is_gated_act = act_info[0] == "gated"
+
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+ raise ValueError(
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+ "'gated-gelu' or 'relu'"
+ )
+
+ # for backwards compatibility
+ if feed_forward_proj == "gated-gelu":
+ self.dense_act_fn = "gelu_new"
+
@property
def hidden_size(self):
return self.d_model
diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py
index d45ea49645d395..0f35e7f9e41917 100644
--- a/src/transformers/models/mt5/modeling_flax_mt5.py
+++ b/src/transformers/models/mt5/modeling_flax_mt5.py
@@ -14,6 +14,8 @@
# limitations under the License.
""" Flax mT5 model."""
+import numpy as np
+
from ...utils import logging
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from .configuration_mt5 import MT5Config
@@ -25,6 +27,19 @@
_TOKENIZER_FOR_DOC = "T5Tokenizer"
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = np.zeros_like(input_ids)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+ return shifted_input_ids
+
+
class FlaxMT5Model(FlaxT5Model):
r"""
This class overrides [`FlaxT5Model`]. Please check the superclass for the appropriate documentation alongside usage
diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index 314198c69a9a41..8c19a63eded3cd 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -49,13 +49,13 @@ class MT5Model(T5Model):
model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
- r"decoder\.embed_tokens\.weight",
- r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
_keys_to_ignore_on_save = [
- r"encoder\.embed_tokens\.weight",
- r"decoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
]
@@ -84,10 +84,10 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
_keys_to_ignore_on_save = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
@@ -112,8 +112,8 @@ class MT5EncoderModel(T5EncoderModel):
model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
_keys_to_ignore_on_save = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
diff --git a/src/transformers/models/nystromformer/__init__.py b/src/transformers/models/nystromformer/__init__.py
index d3df751dd4f6e6..a239e435f97be8 100644
--- a/src/transformers/models/nystromformer/__init__.py
+++ b/src/transformers/models/nystromformer/__init__.py
@@ -18,14 +18,19 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
"configuration_nystromformer": ["NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "NystromformerConfig"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_nystromformer"] = [
"NYSTROMFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"NystromformerForMaskedLM",
@@ -42,7 +47,12 @@
if TYPE_CHECKING:
from .configuration_nystromformer import NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, NystromformerConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_nystromformer import (
NYSTROMFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
NystromformerForMaskedLM,
diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py
index 636e5df108ce18..b5813af781b72f 100755
--- a/src/transformers/models/nystromformer/modeling_nystromformer.py
+++ b/src/transformers/models/nystromformer/modeling_nystromformer.py
@@ -624,7 +624,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
diff --git a/src/transformers/models/openai/__init__.py b/src/transformers/models/openai/__init__.py
index 3abba0b781bc5b..8aaaaa62a98967 100644
--- a/src/transformers/models/openai/__init__.py
+++ b/src/transformers/models/openai/__init__.py
@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +32,20 @@
"tokenization_openai": ["OpenAIGPTTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_openai_fast"] = ["OpenAIGPTTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_openai"] = [
"OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"OpenAIGPTDoubleHeadsModel",
@@ -40,7 +56,12 @@
"load_tf_weights_in_openai_gpt",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_openai"] = [
"TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFOpenAIGPTDoubleHeadsModel",
@@ -56,10 +77,20 @@
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .tokenization_openai import OpenAIGPTTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_openai_fast import OpenAIGPTTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_openai import (
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
OpenAIGPTDoubleHeadsModel,
@@ -70,7 +101,12 @@
load_tf_weights_in_openai_gpt,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
diff --git a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py
index b57f2dd0339fcc..1b101aea0cc0de 100755
--- a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py
@@ -64,8 +64,10 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
"--openai_config_file",
default="",
type=str,
- help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
- "This specifies the model architecture.",
+ help=(
+ "An optional config json file corresponding to the pre-trained OpenAI model. \n"
+ "This specifies the model architecture."
+ ),
)
args = parser.parse_args()
convert_openai_checkpoint_to_pytorch(
diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py
index 8ded535cef102e..f5136781378135 100644
--- a/src/transformers/models/openai/modeling_openai.py
+++ b/src/transformers/models/openai/modeling_openai.py
@@ -81,12 +81,14 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
# Check that the token and position embeddings weight dimensions map those of the init parameters.
if model.tokens_embed.weight.shape != init_params[1].shape:
raise ValueError(
- f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape: {init_params[1].shape}"
+ f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:"
+ f" {init_params[1].shape}"
)
if model.positions_embed.weight.shape != init_params[0].shape:
raise ValueError(
- f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape: {init_params[0].shape}"
+ f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:"
+ f" {init_params[0].shape}"
)
model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
@@ -674,7 +676,7 @@ def forward(
>>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-gpt")
>>> tokenizer.add_special_tokens(
... {"cls_token": "[CLS]"}
- >>> ) # Add a [CLS] to the vocabulary (we should train it also!)
+ ... ) # Add a [CLS] to the vocabulary (we should train it also!)
>>> model.resize_token_embeddings(len(tokenizer))
>>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
@@ -812,7 +814,7 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py
index 5215ad7c2ff6ee..528494836a3cbc 100644
--- a/src/transformers/models/openai/modeling_tf_openai.py
+++ b/src/transformers/models/openai/modeling_tf_openai.py
@@ -693,9 +693,9 @@ def call(
>>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}
>>> inputs["mc_token_ids"] = tf.constant(
... [inputs["input_ids"].shape[-1] - 1, inputs["input_ids"].shape[-1] - 1]
- >>> )[
+ ... )[
... None, :
- >>> ] # Batch size 1
+ ... ] # Batch size 1
>>> outputs = model(inputs)
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
```"""
@@ -851,7 +851,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
diff --git a/src/transformers/models/openai/tokenization_openai.py b/src/transformers/models/openai/tokenization_openai.py
index ca21943a23594f..40bb824cd7186d 100644
--- a/src/transformers/models/openai/tokenization_openai.py
+++ b/src/transformers/models/openai/tokenization_openai.py
@@ -215,7 +215,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
diff --git a/src/transformers/models/opt/__init__.py b/src/transformers/models/opt/__init__.py
new file mode 100644
index 00000000000000..e35d07d1b012d2
--- /dev/null
+++ b/src/transformers/models/opt/__init__.py
@@ -0,0 +1,96 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+
+_import_structure = {"configuration_opt": ["OPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OPTConfig"]}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_opt"] = [
+ "OPT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "OPTForCausalLM",
+ "OPTModel",
+ "OPTPreTrainedModel",
+ ]
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_opt"] = ["TFOPTForCausalLM", "TFOPTModel", "TFOPTPreTrainedModel"]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_opt"] = [
+ "FlaxOPTForCausalLM",
+ "FlaxOPTModel",
+ "FlaxOPTPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_opt import OPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OPTConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_opt import OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, OPTModel, OPTPreTrainedModel
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/opt/configuration_opt.py b/src/transformers/models/opt/configuration_opt.py
new file mode 100644
index 00000000000000..eb7c8e0208d69e
--- /dev/null
+++ b/src/transformers/models/opt/configuration_opt.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2022 The Metaseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" OPT model configuration"""
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+OPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/opt-125m": "https://huggingface.co/facebook/opt-125m/blob/main/config.json",
+ "facebook/opt-350m": "https://huggingface.co/facebook/opt-350m/blob/main/config.json",
+ "facebook/opt-1.3b": "https://huggingface.co/facebook/opt-1.3b/blob/main/config.json",
+ "facebook/opt-2.7b": "https://huggingface.co/facebook/opt-2.7b/blob/main/config.json",
+ "facebook/opt-6.7b": "https://huggingface.co/facebook/opt-6.7b/blob/main/config.json",
+ "facebook/opt-13b": "https://huggingface.co/facebook/opt-13b/blob/main/config.json",
+}
+
+
+class OPTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`OPTModel`]. It is used to instantiate a OPT model
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the OPT
+ [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50272):
+ Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`OPTModel`]
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ ffn_dim (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ do_layer_norm_before (`bool`, *optional*, defaults to `True`):
+ Whether to perform layer normalization before the attention block.
+ word_embed_proj_dim (`int`, *optional*):
+ `word_embed_proj_dim` can be set to down-project word embeddings, *e.g.* `opt-350m`. Defaults to
+ `hidden_size`.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ layerdrop: (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
+ details.
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+
+ Example:
+
+ ```python
+ >>> from transformers import OPTModel, OPTConfig
+
+ >>> # Initializing a OPT facebook/opt-large style configuration
+ >>> configuration = OPTConfig()
+
+ >>> # Initializing a model from the facebook/opt-large style configuration
+ >>> model = OPTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "opt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=50272,
+ hidden_size=768,
+ num_hidden_layers=12,
+ ffn_dim=3072,
+ max_position_embeddings=2048,
+ do_layer_norm_before=True,
+ word_embed_proj_dim=None,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ num_attention_heads=12,
+ activation_function="relu",
+ layerdrop=0.0,
+ init_std=0.02,
+ use_cache=True,
+ pad_token_id=1,
+ bos_token_id=2,
+ eos_token_id=2,
+ **kwargs
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.num_attention_heads = num_attention_heads
+ self.word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size
+ self.ffn_dim = ffn_dim
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.layerdrop = layerdrop
+ self.use_cache = use_cache
+ self.do_layer_norm_before = do_layer_norm_before
diff --git a/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 00000000000000..5992dc7e9a36b1
--- /dev/null
+++ b/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,93 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert OPT checkpoint."""
+
+
+import argparse
+from pathlib import Path
+
+import torch
+
+from transformers import OPTConfig, OPTModel
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def load_checkpoint(checkpoint_path):
+ """Checkpoint path should end in model.pt"""
+ sd = torch.load(checkpoint_path, map_location="cpu")
+ if "model" in sd.keys():
+ sd = torch.load(checkpoint_path, map_location="cpu")["model"]
+
+ # pop unnecessary weights
+ keys_to_delete = [
+ "decoder.version",
+ "decoder.layer_norm.weight",
+ "decoder.layer_norm.bias",
+ "decoder.output_projection.weight",
+ ]
+ for key in keys_to_delete:
+ if key in sd:
+ sd.pop(key)
+
+ keys_to_rename = {
+ "decoder.project_in_dim.weight": "decoder.project_in.weight",
+ "decoder.project_out_dim.weight": "decoder.project_out.weight",
+ }
+ for old_key, new_key in keys_to_rename.items():
+ if old_key in sd:
+ sd[new_key] = sd.pop(old_key)
+
+ return sd
+
+
+@torch.no_grad()
+def convert_opt_checkpoint(checkpoint_path, pytorch_dump_folder_path, config=None):
+ """
+ Copy/paste/tweak model's weights to our BERT structure.
+ """
+ state_dict = load_checkpoint(checkpoint_path)
+
+ if config is not None:
+ config = OPTConfig.from_pretrained(config)
+ else:
+ config = OPTConfig()
+
+ model = OPTModel(config).half().eval()
+ model.load_state_dict(state_dict)
+
+ # Check results
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+ model.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--fairseq_path",
+ type=str,
+ help=(
+ "path to fairseq checkpoint in correct format. You can find all checkpoints in the correct format here:"
+ " https://huggingface.co/models?other=opt_metasq"
+ ),
+ )
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+ parser.add_argument("--hf_config", default=None, type=str, help="Define HF config.")
+ args = parser.parse_args()
+ convert_opt_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, config=args.hf_config)
diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py
new file mode 100644
index 00000000000000..f84d56b0d8b1ea
--- /dev/null
+++ b/src/transformers/models/opt/modeling_flax_opt.py
@@ -0,0 +1,795 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Flax OPT model."""
+
+from functools import partial
+from typing import Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxMaskedLMOutput
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, logging
+from .configuration_opt import OPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
+_CONFIG_FOR_DOC = "OPTConfig"
+_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+
+OPT_START_DOCSTRING = r"""
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+OPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->OPT
+class FlaxOPTAttention(nn.Module):
+ config: OPTConfig
+ embed_dim: int
+ num_heads: int
+ dropout: float = 0.0
+ causal: bool = False
+ bias: bool = True
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self) -> None:
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ dense = partial(
+ nn.Dense,
+ self.embed_dim,
+ use_bias=self.bias,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
+ self.out_proj = dense()
+
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slighly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ key_value_states: Optional[jnp.ndarray] = None,
+ attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states)
+ value_states = self.v_proj(key_value_states)
+ else:
+ # self_attention
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
+
+ # Convert the boolean attention mask to an attention bias.
+ if attention_mask is not None:
+ # attention mask in the form of attention bias
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class FlaxOPTDecoderLayer(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self) -> None:
+ self.embed_dim = self.config.hidden_size
+ self.self_attn = FlaxOPTAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.num_attention_heads,
+ dropout=self.config.attention_dropout,
+ causal=True,
+ dtype=self.dtype,
+ )
+ self.do_layer_norm_before = self.config.do_layer_norm_before
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+ self.activation_fn = ACT2FN[self.config.activation_function]
+
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.fc1 = nn.Dense(
+ self.config.ffn_dim,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.fc2 = nn.Dense(
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+ )
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ attention_mask: jnp.ndarray,
+ init_cache: bool = False,
+ output_attentions: bool = True,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ )
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ hidden_states_shape = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+ hidden_states = (residual + hidden_states).reshape(hidden_states_shape)
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+class FlaxOPTDecoderLayerCollection(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxOPTDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
+ self.layerdrop = self.config.layerdrop
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ):
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ outputs = [hidden_states, all_hidden_states, all_self_attns]
+ return outputs
+
+
+class FlaxOPTLearnedPositionalEmbedding(nn.Embed):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def setup(self):
+ self.offset = 2
+ self.embedding = self.param(
+ "embedding", self.embedding_init, (self.num_embeddings + self.offset, self.features), self.param_dtype
+ )
+
+ def __call__(self, positions):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+
+ return super().__call__(positions + self.offset)
+
+
+class FlaxOPTDecoder(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ offset: int = 2
+
+ def setup(self):
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+ embed_dim = self.config.hidden_size
+ self.padding_idx = self.config.pad_token_id
+ self.max_target_positions = self.config.max_position_embeddings
+
+ self.embed_tokens = nn.Embed(
+ self.config.vocab_size,
+ self.config.word_embed_proj_dim,
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
+ self.config.max_position_embeddings,
+ embed_dim,
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ if self.config.word_embed_proj_dim != self.config.hidden_size:
+ self.project_in = nn.Dense(self.config.hidden_size, use_bias=False)
+ self.project_out = nn.Dense(self.config.word_embed_proj_dim, use_bias=False)
+
+ else:
+ self.project_in = None
+ self.project_out = None
+
+ self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ input_shape = input_ids.shape
+ input_ids = input_ids.reshape(-1, input_shape[-1])
+
+ inputs_embeds = self.embed_tokens(input_ids)
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ positions = self.embed_positions(position_ids)
+
+ hidden_states = inputs_embeds + positions
+
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+ hidden_state, all_hidden_states, attentions = self.layers(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if self.project_out is not None:
+ hidden_state = self.project_out(hidden_state)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_state,)
+
+ outputs = [hidden_state, all_hidden_states, attentions]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_state,
+ hidden_states=all_hidden_states,
+ attentions=attentions,
+ )
+
+
+class FlaxOPTPreTrainedModel(FlaxPreTrainedModel):
+ config_class = OPTConfig
+ base_model_prefix: str = "model"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: OPTConfig,
+ input_shape: Tuple[int] = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ position_ids,
+ return_dict=False,
+ )
+
+ random_params = module_init_outputs["params"]
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ params: dict = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ dropout_rng: PRNGKey = None,
+ deterministic: bool = True,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ if position_ids is None:
+ position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ inputs = {"params": params or self.params}
+
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
+ # changed by FlaxOPTAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+class FlaxOPTModule(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.decoder = FlaxOPTDecoder(self.config, dtype=self.dtype)
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ init_cache=False,
+ ):
+
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModel with Bart->OPT
+class FlaxOPTModel(FlaxOPTPreTrainedModel):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ module_class = FlaxOPTModule
+
+
+append_call_sample_docstring(
+ FlaxOPTModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
+)
+
+
+@add_start_docstrings(
+ "The bare OPT Model transformer outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class FlaxOPTForCausalLMModule(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.model = FlaxOPTModule(config=self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+
+ outputs = self.model(
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxMaskedLMOutput(
+ logits=lm_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ OPT Model with a language modeling head on top (linear layer with weights tied to the input embeddings) e.g for
+ autoregressive tasks.
+ """,
+ OPT_START_DOCSTRING,
+)
+class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
+ module_class = FlaxOPTForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ FlaxOPTForCausalLM,
+ _TOKENIZER_FOR_DOC,
+ _CHECKPOINT_FOR_DOC,
+ FlaxBaseModelOutput,
+ _CONFIG_FOR_DOC,
+)
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
new file mode 100644
index 00000000000000..6db58a82d61a33
--- /dev/null
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -0,0 +1,963 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch OPT model."""
+import random
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_opt import OPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
+_CONFIG_FOR_DOC = "OPTConfig"
+_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+
+OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/opt-125m",
+ "facebook/opt-350m",
+ "facebook/opt-1.3b",
+ "facebook/opt-2.7b",
+ "facebook/opt-6.7b",
+ "facebook/opt-13b",
+ "facebook/opt-30b",
+ # See all OPT models at https://huggingface.co/models?filter=opt
+]
+
+
+def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+ mask_cond = torch.arange(mask.size(-1))
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class OPTLearnedPositionalEmbedding(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim)
+
+ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ attention_mask = attention_mask.long()
+
+ # create positions depending on attention_mask
+ positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
+
+ # cut positions if `past_key_values_length` is > 0
+ positions = positions[:, past_key_values_length:]
+
+ return super().forward(positions + self.offset)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->OPT
+class OPTAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned aross GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+class OPTDecoderLayer(nn.Module):
+ def __init__(self, config: OPTConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = OPTAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ )
+ self.do_layer_norm_before = config.do_layer_norm_before
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ hidden_states_shape = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ hidden_states = (residual + hidden_states).view(hidden_states_shape)
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+OPT_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`OPTConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class OPTPreTrainedModel(PreTrainedModel):
+ config_class = OPTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["OPTDecoderLayer"]
+ _keys_to_ignore_on_load_unexpected = [r"decoder.version"]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (OPTDecoder)):
+ module.gradient_checkpointing = value
+
+
+OPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class OPTDecoder(OPTPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
+
+ Args:
+ config: OPTConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: OPTConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_target_positions = config.max_position_embeddings
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
+ self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
+
+ if config.word_embed_proj_dim != config.hidden_size:
+ self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
+ else:
+ self.project_out = None
+
+ if config.word_embed_proj_dim != config.hidden_size:
+ self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
+ else:
+ self.project_in = None
+
+ self.layer_norm = None
+ self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
+ ).to(inputs_embeds.device)
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ hidden_states = inputs_embeds + pos_embeds
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ None,
+ )
+ else:
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if self.project_out is not None:
+ hidden_states = self.project_out(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@add_start_docstrings(
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class OPTModel(OPTPreTrainedModel):
+ def __init__(self, config: OPTConfig):
+ super().__init__(config)
+ self.decoder = OPTDecoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.decoder.embed_tokens = value
+
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ )
+
+
+class OPTForCausalLM(OPTPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = OPTModel(config)
+
+ # the lm_head weight is automatically tied to the embed tokens weight
+ self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import GPT2Tokenizer, OPTForCausalLM
+
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = self.lm_head(outputs[0]).contiguous()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ if past:
+ input_ids = input_ids[:, -1:]
+ # first step, decoder_cached_states are empty
+ return {
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
+
+ @staticmethod
+ def _reorder_cache(past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py
new file mode 100644
index 00000000000000..0c3de0ce2069f0
--- /dev/null
+++ b/src/transformers/models/opt/modeling_tf_opt.py
@@ -0,0 +1,1024 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF 2.0 OPT model."""
+
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
+
+# Public API
+from ...modeling_tf_utils import (
+ DUMMY_INPUTS,
+ TFCausalLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSharedEmbeddings,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_opt import OPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
+_CONFIG_FOR_DOC = "OPTConfig"
+_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+LARGE_NEGATIVE = -1e8
+
+
+def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+ mask_cond = tf.range(shape_list(mask)[-1])
+
+ mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+
+ if past_key_values_length > 0:
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
+
+ return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
+
+
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ src_len = shape_list(mask)[1]
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
+
+
+class TFOPTLearnedPositionalEmbedding(TFSharedEmbeddings):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
+ # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)
+
+ def call(self, attention_mask, past_key_values_length: int = 0):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ attention_mask = tf.cast(attention_mask, tf.int64)
+
+ # create positions depending on attention_mask
+ positions = tf.math.cumsum(attention_mask, axis=1) * attention_mask - 1
+
+ # cut positions if `past_key_values_length` is > 0
+ positions = positions[:, past_key_values_length:]
+
+ return super().call(positions + self.offset)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->OPT
+class TFOPTAttention(tf.keras.layers.Layer):
+ """Multi-headed attention from "Attention Is All You Need"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = tf.keras.layers.Dropout(dropout)
+ self.head_dim = embed_dim // num_heads
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
+ self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
+ self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
+ self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
+
+ def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+ return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ key_value_states: Optional[tf.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
+ attention_mask: Optional[tf.Tensor] = None,
+ layer_head_mask: Optional[tf.Tensor] = None,
+ training: Optional[bool] = False,
+ ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = tf.concat([past_key_value[0], key_states], axis=2)
+ value_states = tf.concat([past_key_value[1], value_states], axis=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
+ key_states = tf.reshape(key_states, proj_shape)
+ value_states = tf.reshape(value_states, proj_shape)
+
+ src_len = shape_list(key_states)[1]
+ attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_weights),
+ [bsz * self.num_heads, tgt_len, src_len],
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
+ )
+
+ if attention_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attention_mask),
+ [bsz, 1, tgt_len, src_len],
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
+ )
+
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
+ attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_weights = stable_softmax(attn_weights, axis=-1)
+
+ if layer_head_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
+ )
+
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+ )
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_probs = self.dropout(attn_weights, training=training)
+ attn_output = tf.matmul(attn_probs, value_states)
+
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_output),
+ [bsz * self.num_heads, tgt_len, self.head_dim],
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
+ )
+
+ attn_output = tf.transpose(
+ tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
+ )
+ attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+ attn_output = self.out_proj(attn_output)
+ attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
+
+ return attn_output, attn_weights, past_key_value
+
+
+class TFOPTDecoderLayer(tf.keras.layers.Layer):
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.do_layer_norm_before = config.do_layer_norm_before
+ self.embed_dim = config.hidden_size
+ self.self_attn = TFOPTAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ name="self_attn",
+ is_decoder=True,
+ )
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
+ self.activation_fn = get_tf_activation(config.activation_function)
+
+ self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+ self.fc1 = tf.keras.layers.Dense(config.ffn_dim, name="fc1")
+ self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
+ self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ layer_head_mask: Optional[tf.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ training: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
+ """
+ Args:
+ hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
+ attention_mask (`tf.Tensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`tf.Tensor`, *optional*): mask for attention heads in a given layer of size
+ `(decoder_attention_heads,)`
+ past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return (hidden_states, self_attn_weights, present_key_value)
+
+
+OPT_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TF 2.0 models accepts two formats as inputs:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional arguments.
+
+ This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the
+ tensors in the first argument of the model call function: `model(inputs)`.
+
+ If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
+ first positional argument :
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+
+
+ Args:
+ config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class TFOPTPreTrainedModel(TFPreTrainedModel):
+ """
+ TFOPT Pretrained Model that inheritates from transformers.TFPreTrainedModel
+
+ Args:
+ config: OPTConfig
+ """
+
+ config_class = OPTConfig
+ base_model_prefix = "model"
+
+ @property
+ def dummy_inputs(self):
+ pad_token = 1
+ input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
+ dummy_inputs = {
+ "attention_mask": tf.math.not_equal(input_ids, pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+ @tf.function(
+ input_signature=[
+ {
+ "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ output = self.call(inputs)
+
+ return self.serving_output(output)
+
+
+OPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
+ contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training, `True` during generation
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@keras_serializable
+class TFOPTDecoder(tf.keras.layers.Layer):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, load_weight_prefix=None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.layerdrop = config.layerdrop
+ num_embeddings = config.max_position_embeddings
+ self.embed_tokens = TFSharedEmbeddings(
+ config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="embed_tokens"
+ )
+ self.embed_positions = TFOPTLearnedPositionalEmbedding(
+ num_embeddings,
+ config.hidden_size,
+ name="embed_positions",
+ )
+
+ if config.word_embed_proj_dim != config.hidden_size:
+ self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
+ self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
+
+ else:
+ self.project_in = None
+ self.project_out = None
+
+ self.layers = [TFOPTDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
+
+ def get_embed_tokens(self):
+ return self.embed_tokens
+
+ def set_embed_tokens(self, embed_tokens):
+ self.embed_tokens = embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens.vocab_size = new_embeddings.shape[0]
+ self.embed_tokens.weight = new_embeddings
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
+ # create causal mask
+ # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+ else:
+ combined_attention_mask = _expand_mask(
+ tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+ )
+
+ if attention_mask is not None:
+ combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+
+ return combined_attention_mask
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+ r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+ decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`tf.Tensor` of
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
+ embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is None:
+ attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
+
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ hidden_states = inputs_embeds + pos_embeds
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ present_key_values = () if use_cache else None
+
+ # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
+ if attn_mask is not None and tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_mask)[0],
+ len(self.layers),
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ hidden_states, layer_self_attn, present_key_value = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=head_mask[idx] if head_mask is not None else None,
+ past_key_value=past_key_value,
+ )
+
+ if use_cache:
+ present_key_values += (present_key_value,)
+
+ if output_attentions:
+ all_self_attns += (layer_self_attn,)
+
+ if self.project_out is not None:
+ hidden_states = self.project_out(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+
+ else:
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@keras_serializable
+class TFOPTMainLayer(tf.keras.layers.Layer):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.decoder = TFOPTDecoder(config, name="decoder")
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.decoder.set_input_embeddings(new_embeddings)
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.decoder(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare TF OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+@keras_serializable
+class TFOPTModel(TFOPTPreTrainedModel):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.config = config
+ self.model = TFOPTMainLayer(config, name="model")
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.model.set_input_embeddings(new_embeddings)
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=output.last_hidden_state,
+ past_key_values=pkv,
+ hidden_states=hs,
+ attentions=attns,
+ )
+
+
+@add_start_docstrings(
+ """
+ The OPT Model transformer with a language modeling head on top.
+ """,
+ OPT_START_DOCSTRING,
+)
+@keras_serializable
+class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.config = config
+ self.model = TFOPTMainLayer(config, name="model")
+
+ def get_output_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+ attention_mask = kwargs.get("attention_mask", None)
+
+ # only last token for inputs_ids if past is defined in kwargs
+ if past:
+ inputs = tf.expand_dims(inputs[:, -1], -1)
+
+ return {
+ "input_ids": inputs,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
+
+ @unpack_inputs
+ @replace_return_docstrings(output_type=TFCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFCausalLMOutputWithPast, Tuple[tf.Tensor]]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import GPT2Tokenizer, TFOPTForCausalLM
+
+ >>> model = TFOPTForCausalLM.from_pretrained("facebook/opt-350m")
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="tf")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ logits = self.model.decoder.embed_tokens(outputs[0], mode="linear")
+ loss = None
+ if labels is not None:
+ # shift labels to the left and cut last logit token
+ shifted_logits = logits[:, :-1]
+ labels = labels[:, 1:]
+ loss = self.hf_compute_loss(labels, shifted_logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+ return TFCausalLMOutputWithPast(
+ past_key_values=pkv,
+ hidden_states=hs,
+ attentions=attns,
+ loss=output.loss,
+ logits=output.logits,
+ )
diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py
index 4d01c31c6df234..ca04afeeb1a078 100644
--- a/src/transformers/models/pegasus/__init__.py
+++ b/src/transformers/models/pegasus/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig"],
-}
+_import_structure = {"configuration_pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_pegasus"] = ["PegasusTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_pegasus_fast"] = ["PegasusTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_pegasus"] = [
"PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
"PegasusForCausalLM",
@@ -46,14 +60,24 @@
"PegasusPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_pegasus"] = [
"FlaxPegasusForConditionalGeneration",
"FlaxPegasusModel",
@@ -64,13 +88,28 @@
if TYPE_CHECKING:
from .configuration_pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_pegasus import PegasusTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_pegasus_fast import PegasusTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_pegasus import (
PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
PegasusForCausalLM,
@@ -79,10 +118,20 @@
PegasusPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_pegasus import (
FlaxPegasusForConditionalGeneration,
FlaxPegasusModel,
diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py
index 81276dcd2adc72..303d0055716ca6 100644
--- a/src/transformers/models/pegasus/modeling_flax_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -544,7 +544,7 @@ def setup(self) -> None:
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
- self.config.encoder_ffn_dim,
+ self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 2f79fa93fe5a3e..25a8676d6f1e4a 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -102,7 +102,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
@@ -233,7 +233,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -249,7 +250,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -270,7 +272,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -775,7 +778,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -872,7 +876,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1043,7 +1047,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -1285,10 +1290,10 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
- r"embed_positions\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
+ r"embed_positions.weight",
]
def __init__(self, config: PegasusConfig):
diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py
index be2539b3a9100d..2c5696f94d36ea 100644
--- a/src/transformers/models/pegasus/modeling_tf_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -268,7 +268,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -278,7 +281,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -294,7 +300,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -311,7 +320,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -787,7 +799,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
# encoder layers
@@ -989,7 +1004,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
- message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1317,7 +1335,7 @@ def call(
if labels is not None:
labels = tf.where(
labels == self.config.pad_token_id,
- tf.fill(shape_list(labels), -100),
+ tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
labels,
)
use_cache = False
diff --git a/src/transformers/models/pegasus/tokenization_pegasus.py b/src/transformers/models/pegasus/tokenization_pegasus.py
index a6a9167e66deff..b4d1cdc19804ce 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus.py
@@ -119,7 +119,8 @@ def __init__(
if additional_special_tokens is not None:
if not isinstance(additional_special_tokens, list):
raise TypeError(
- f"additional_special_tokens should be of type {type(list)}, but is {type(additional_special_tokens)}"
+ f"additional_special_tokens should be of type {type(list)}, but is"
+ f" {type(additional_special_tokens)}"
)
additional_special_tokens_extended = (
@@ -134,7 +135,8 @@ def __init__(
if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
raise ValueError(
- f"Please make sure that the provided additional_special_tokens do not contain an incorrectly shifted list of tokens. Found {additional_special_tokens_extended}."
+ "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+ f" shifted list of tokens. Found {additional_special_tokens_extended}."
)
additional_special_tokens = additional_special_tokens_extended
else:
diff --git a/src/transformers/models/pegasus/tokenization_pegasus_fast.py b/src/transformers/models/pegasus/tokenization_pegasus_fast.py
index 14399988f0fa29..22c6018385f6d0 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus_fast.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus_fast.py
@@ -115,7 +115,8 @@ def __init__(
if additional_special_tokens is not None:
if not isinstance(additional_special_tokens, list):
raise TypeError(
- f"additional_special_tokens should be of type {type(list)}, but is {type(additional_special_tokens)}"
+ f"additional_special_tokens should be of type {type(list)}, but is"
+ f" {type(additional_special_tokens)}"
)
additional_special_tokens_extended = (
@@ -130,7 +131,8 @@ def __init__(
if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
raise ValueError(
- f"Please make sure that the provided additional_special_tokens do not contain an incorrectly shifted list of tokens. Found {additional_special_tokens_extended}."
+ "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+ f" shifted list of tokens. Found {additional_special_tokens_extended}."
)
additional_special_tokens = additional_special_tokens_extended
else:
@@ -158,7 +160,8 @@ def _special_token_mask(self, seq):
if all_special_ids != set(range(len(self.additional_special_tokens) + 3)):
raise ValueError(
- f"There should be 3 special tokens: mask_token, pad_token, and eos_token + {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
+ "There should be 3 special tokens: mask_token, pad_token, and eos_token +"
+ f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
)
return [1 if x in all_special_ids else 0 for x in seq]
diff --git a/src/transformers/models/perceiver/__init__.py b/src/transformers/models/perceiver/__init__.py
index b2081830643467..107c62f2eb8ad5 100644
--- a/src/transformers/models/perceiver/__init__.py
+++ b/src/transformers/models/perceiver/__init__.py
@@ -17,18 +17,34 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
_import_structure = {
- "configuration_perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig"],
+ "configuration_perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverOnnxConfig"],
"tokenization_perceiver": ["PerceiverTokenizer"],
}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_perceiver"] = ["PerceiverFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_perceiver"] = [
"PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST",
"PerceiverForImageClassificationConvProcessing",
@@ -45,13 +61,23 @@
if TYPE_CHECKING:
- from .configuration_perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig
+ from .configuration_perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverOnnxConfig
from .tokenization_perceiver import PerceiverTokenizer
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_perceiver import PerceiverFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_perceiver import (
PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST,
PerceiverForImageClassificationConvProcessing,
diff --git a/src/transformers/models/perceiver/configuration_perceiver.py b/src/transformers/models/perceiver/configuration_perceiver.py
index fdf1f01243505f..0c97974441c53d 100644
--- a/src/transformers/models/perceiver/configuration_perceiver.py
+++ b/src/transformers/models/perceiver/configuration_perceiver.py
@@ -14,8 +14,15 @@
# limitations under the License.
""" Perceiver model configuration"""
+from collections import OrderedDict
+from typing import Any, Mapping, Optional, Union
+
from ...configuration_utils import PretrainedConfig
-from ...utils import logging
+from ...feature_extraction_utils import FeatureExtractionMixin
+from ...onnx import OnnxConfig
+from ...onnx.utils import compute_effective_axis_dimension
+from ...tokenization_utils_base import PreTrainedTokenizerBase
+from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
@@ -172,3 +179,63 @@ def __init__(
self.audio_samples_per_frame = audio_samples_per_frame
self.samples_per_patch = samples_per_patch
self.output_shape = output_shape
+
+
+class PerceiverOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("inputs", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-4
+
+ def generate_dummy_inputs(
+ self,
+ preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
+ batch_size: int = -1,
+ seq_length: int = -1,
+ num_choices: int = -1,
+ is_pair: bool = False,
+ framework: Optional[TensorType] = None,
+ num_channels: int = 3,
+ image_width: int = 40,
+ image_height: int = 40,
+ ) -> Mapping[str, Any]:
+ # copied from `transformers.onnx.config.OnnxConfig` and slightly altered/simplified
+
+ if isinstance(preprocessor, PreTrainedTokenizerBase):
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+ batch_size = compute_effective_axis_dimension(
+ batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+ )
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+ token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
+ seq_length = compute_effective_axis_dimension(
+ seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+ )
+ # Generate dummy inputs according to compute batch and sequence
+ dummy_input = [" ".join(["a"]) * seq_length] * batch_size
+ inputs = dict(preprocessor(dummy_input, return_tensors=framework))
+ inputs["inputs"] = inputs.pop("input_ids")
+ return inputs
+ elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+ batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
+ dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+ inputs = dict(preprocessor(images=dummy_input, return_tensors=framework))
+ inputs["inputs"] = inputs.pop("pixel_values")
+ return inputs
+ else:
+ raise ValueError(
+ "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
+ )
diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py
index b8acef6c226c47..364bc67c8dc3a6 100755
--- a/src/transformers/models/perceiver/modeling_perceiver.py
+++ b/src/transformers/models/perceiver/modeling_perceiver.py
@@ -19,7 +19,7 @@
from dataclasses import dataclass
from functools import reduce
from operator import __add__
-from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import numpy as np
import torch
@@ -177,7 +177,7 @@ def __init__(self, config):
super().__init__()
self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents))
- def forward(self, batch_size):
+ def forward(self, batch_size: int):
return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang
@@ -232,13 +232,13 @@ def transpose_for_scores(self, x, channels_per_head):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- inputs=None,
- inputs_mask=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs: Optional[torch.FloatTensor] = None,
+ inputs_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
hidden_states = self.layernorm1(hidden_states)
inputs = self.layernorm2(inputs)
@@ -301,7 +301,7 @@ def __init__(self, config, input_channels, output_channels):
super().__init__()
self.dense = nn.Linear(input_channels, output_channels)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
return hidden_states
@@ -377,13 +377,13 @@ def prune_heads(self, heads):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- inputs=None,
- inputs_mask=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs: Optional[torch.FloatTensor] = None,
+ inputs_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
@@ -418,7 +418,7 @@ def __init__(self, config, input_size, widening_factor):
self.intermediate_act_fn = config.hidden_act
self.dense2 = nn.Linear(widening_factor * input_size, input_size)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense1(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dense2(hidden_states)
@@ -456,13 +456,13 @@ def __init__(
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- inputs=None,
- inputs_mask=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs: Optional[torch.FloatTensor] = None,
+ inputs_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
attention_outputs = self.attention(
hidden_states,
attention_mask,
@@ -543,15 +543,15 @@ def __init__(self, config, kv_dim=None):
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- inputs=None,
- inputs_mask=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs: Optional[torch.FloatTensor] = None,
+ inputs_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions else None
@@ -754,14 +754,14 @@ class PreTrainedModel
@replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- inputs,
- attention_mask=None,
- subsampled_output_points=None,
- head_mask=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ inputs: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, PerceiverModelOutput]:
r"""
Returns:
@@ -864,8 +864,8 @@ def forward(
inputs_without_pos = None
if inputs.size()[-1] != self.config.d_model:
raise ValueError(
- f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
- "Make sure to set config.d_model appropriately."
+ f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model:"
+ f" {self.config.d_model}. Make sure to set config.d_model appropriately."
)
batch_size, seq_length, _ = inputs.size()
@@ -1871,7 +1871,7 @@ def forward(
self,
inputs: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
- subsampled_output_points: Optional[Dict[str, torch.tensor]] = None,
+ subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -2020,7 +2020,9 @@ def __init__(self, config):
def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
return None
- def forward(self, query, z, query_mask=None):
+ def forward(
+ self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
+ ) -> torch.FloatTensor:
# (batch_size, num_latents, d_latents) -> (batch_size, d_latents)
z = torch.mean(z, dim=1)
# (batch_size, d_latents) -> (batch_size, config.num_labels)
@@ -2044,11 +2046,11 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
The type of position encoding to use. Can be either "trainable", "fourier", or "none".
output_index_dims (`int`, *optional*):
The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'.
- num_channels (`int`, *optional*):
+ num_channels (`int`, *optional*, defaults to 128):
The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'.
qk_channels (`int`, *optional*):
The number of channels of the queries and keys in the cross-attention layer.
- v_channels (`int`, *optional*, defaults to 128):
+ v_channels (`int`, *optional*):
The number of channels of the values in the cross-attention layer.
num_heads (`int`, *optional*, defaults to 1):
The number of attention heads in the cross-attention layer.
@@ -2066,23 +2068,23 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
def __init__(
self,
- config,
- output_num_channels,
- position_encoding_type="trainable",
+ config: PerceiverConfig,
+ output_num_channels: int,
+ position_encoding_type: Optional[str] = "trainable",
# The following 2 arguments are ignored if position_encoding_type == 'none':
- output_index_dims=None,
- num_channels=128,
- subsampled_index_dims=None,
- qk_channels=None,
- v_channels=None,
- num_heads=1,
- widening_factor=1,
- use_query_residual=False,
- concat_preprocessed_input=False,
- final_project=True,
- position_encoding_only=False,
+ output_index_dims: Optional[int] = None,
+ num_channels: Optional[int] = 128,
+ subsampled_index_dims: Optional[int] = None,
+ qk_channels: Optional[int] = None,
+ v_channels: Optional[int] = None,
+ num_heads: Optional[int] = 1,
+ widening_factor: Optional[int] = 1,
+ use_query_residual: Optional[bool] = False,
+ concat_preprocessed_input: Optional[bool] = False,
+ final_project: Optional[bool] = True,
+ position_encoding_only: Optional[bool] = False,
**position_encoding_kwargs,
- ):
+ ) -> None:
super().__init__()
self.output_num_channels = output_num_channels
@@ -2183,7 +2185,13 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su
return pos_emb
- def forward(self, query, z, query_mask=None, output_attentions=False):
+ def forward(
+ self,
+ query: torch.Tensor,
+ z: torch.FloatTensor,
+ query_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> PerceiverDecoderOutput:
# Cross-attention decoding.
# key, value: B x N x K; query: B x M x K
# Attention maps -> B x N x M
@@ -2239,7 +2247,13 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su
inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points
)
- def forward(self, query, z, query_mask=None, output_attentions=False):
+ def forward(
+ self,
+ query: torch.Tensor,
+ z: torch.FloatTensor,
+ query_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
# B x 1 x num_classes -> B x num_classes
@@ -2268,7 +2282,13 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su
raise ValueError("FlowDecoder doesn't support subsampling yet.")
return inputs
- def forward(self, query, z, query_mask=None, output_attentions=False):
+ def forward(
+ self,
+ query: torch.Tensor,
+ z: torch.FloatTensor,
+ query_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
preds = decoder_outputs.logits
# Output flow and rescale.
@@ -2291,7 +2311,9 @@ class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
The type of position encoding to use. Can be either "trainable", "fourier", or "none".
"""
- def __init__(self, config, output_shape, position_encoding_type, **decoder_kwargs):
+ def __init__(
+ self, config: PerceiverConfig, output_shape: List[int], position_encoding_type: str, **decoder_kwargs
+ ) -> None:
super().__init__()
if len(output_shape) != 4: # B, T, H, W
raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.")
@@ -2318,7 +2340,9 @@ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, su
subsampled_points=subsampled_points,
)
- def forward(self, query, z, query_mask=None):
+ def forward(
+ self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
+ ) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z)
logits = decoder_outputs.logits
@@ -2378,14 +2402,14 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
def __init__(
self,
- config,
- modalities,
- num_outputs,
- output_num_channels,
- min_padding_size=2,
- subsampled_index_dims=None,
+ config: PerceiverConfig,
+ modalities: Dict[str, PerceiverAbstractDecoder],
+ num_outputs: int,
+ output_num_channels: int,
+ min_padding_size: Optional[int] = 2,
+ subsampled_index_dims: Optional[Dict[str, PerceiverAbstractDecoder]] = None,
**decoder_kwargs
- ):
+ ) -> None:
super().__init__()
self.modalities = nn.ModuleDict(modalities)
self.subsampled_index_dims = subsampled_index_dims
@@ -2447,7 +2471,13 @@ def embed(modality, x):
[embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1
)
- def forward(self, query, z, query_mask=None, output_attentions=False):
+ def forward(
+ self,
+ query: torch.Tensor,
+ z: torch.FloatTensor,
+ query_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> torch.Tensor:
# B x 1 x num_classes -> B x num_classes
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
@@ -2680,7 +2710,7 @@ def num_dimensions(self) -> int:
def output_size(self, *args, **kwargs) -> int:
return self._num_channels
- def forward(self, batch_size):
+ def forward(self, batch_size: int) -> torch.Tensor:
position_embeddings = self.position_embeddings
if batch_size is not None:
@@ -2705,7 +2735,9 @@ def _check_or_build_spatial_positions(pos, index_dims, batch_size):
"""
if pos is None:
pos = build_linear_positions(index_dims)
- pos = torch.broadcast_to(pos[None], (batch_size,) + pos.shape)
+ # equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)`
+ # but `torch.broadcast_to` cannot be converted to ONNX
+ pos = pos[None].expand((batch_size,) + pos.shape)
pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
else:
# Just a warning label: you probably don't want your spatial features to
@@ -2741,7 +2773,9 @@ def output_size(self):
return encoding_size
- def forward(self, index_dims, batch_size, device, pos=None):
+ def forward(
+ self, index_dims: List[int], batch_size: int, device, pos: torch.FloatTensor = None
+ ) -> torch.FloatTensor:
pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)
fourier_pos_enc = generate_fourier_features(
pos,
@@ -2771,7 +2805,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
Model configuration.
"""
- def __init__(self, config):
+ def __init__(self, config: PerceiverConfig) -> None:
super().__init__()
self.config = config
self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
@@ -2781,7 +2815,7 @@ def __init__(self, config):
def num_channels(self) -> int:
return self.config.d_model
- def forward(self, inputs):
+ def forward(self, inputs: torch.LongTensor) -> torch.FloatTensor:
embeddings = self.embeddings(inputs)
seq_length = inputs.shape[1]
@@ -2800,15 +2834,16 @@ class PerceiverEmbeddingDecoder(nn.Module):
Model configuration.
"""
- def __init__(self, config):
+ def __init__(self, config: PerceiverConfig) -> None:
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.bias = nn.Parameter(torch.zeros(self.vocab_size))
- def forward(self, hidden_states, embedding_layer):
+ def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, d_model = hidden_states.shape
- output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.T) # Flatten batch dim
+ # Flatten batch dim
+ output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1))
output = output + self.bias
return output.reshape([batch_size, seq_len, self.vocab_size])
@@ -2859,7 +2894,7 @@ class PerceiverClassificationPostprocessor(nn.Module):
Number of channels in the input.
"""
- def __init__(self, config, in_channels):
+ def __init__(self, config: PerceiverConfig, in_channels: int) -> None:
super().__init__()
self.classifier = nn.Linear(in_channels, config.num_labels)
@@ -2881,7 +2916,7 @@ class PerceiverAudioPostprocessor(nn.Module):
Postprocessor type to use. Currently, only "patches" is supported.
"""
- def __init__(self, config, in_channels, postproc_type: str = "patches"):
+ def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None:
super().__init__()
if postproc_type not in ("patches",): # to be supported: 'conv', 'patches', 'pixels'
@@ -2908,7 +2943,7 @@ class PerceiverProjectionPostprocessor(nn.Module):
Number of channels in the output.
"""
- def __init__(self, in_channels, out_channels):
+ def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__()
self.classifier = nn.Linear(in_channels, out_channels)
@@ -3134,9 +3169,9 @@ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, netw
if self.prep_type != "patches":
# move channels to last dimension, as the _build_network_inputs method below expects this
if inputs.ndim == 4:
- inputs = torch.moveaxis(inputs, 1, -1)
+ inputs = torch.permute(inputs, (0, 2, 3, 1))
elif inputs.ndim == 5:
- inputs = torch.moveaxis(inputs, 2, -1)
+ inputs = torch.permute(inputs, (0, 1, 3, 4, 2))
else:
raise ValueError("Unsupported data format for conv1x1.")
@@ -3155,7 +3190,7 @@ class PerceiverOneHotPreprocessor(AbstractPreprocessor):
Model configuration.
"""
- def __init__(self, config):
+ def __init__(self, config: PerceiverConfig) -> None:
super().__init__()
self.config: PerceiverConfig = config
diff --git a/src/transformers/models/phobert/__init__.py b/src/transformers/models/phobert/__init__.py
index 0f226f537aa9a0..0d9a6f4cea1a3c 100644
--- a/src/transformers/models/phobert/__init__.py
+++ b/src/transformers/models/phobert/__init__.py
@@ -21,9 +21,7 @@
from ...utils import _LazyModule
-_import_structure = {
- "tokenization_phobert": ["PhobertTokenizer"],
-}
+_import_structure = {"tokenization_phobert": ["PhobertTokenizer"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/plbart/__init__.py b/src/transformers/models/plbart/__init__.py
index 676feeb39cfabd..06204a8901e932 100644
--- a/src/transformers/models/plbart/__init__.py
+++ b/src/transformers/models/plbart/__init__.py
@@ -17,17 +17,31 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"],
-}
+_import_structure = {"configuration_plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_plbart"] = ["PLBartTokenizer"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_plbart"] = [
"PLBART_PRETRAINED_MODEL_ARCHIVE_LIST",
"PLBartForCausalLM",
@@ -41,10 +55,20 @@
if TYPE_CHECKING:
from .configuration_plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_plbart import PLBartTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_plbart import (
PLBART_PRETRAINED_MODEL_ARCHIVE_LIST,
PLBartForCausalLM,
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 97e3ec680cbf22..7ca17146f3c5f7 100755
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -116,7 +116,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
@@ -233,7 +233,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -249,7 +250,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -270,7 +272,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -784,7 +787,8 @@ def forward(
if head_mask is not None:
if head_mask.size()[0] != (len(self.layers)):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, encoder_layer in enumerate(self.layers):
@@ -879,7 +883,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1022,7 +1026,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1230,9 +1235,9 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
- r"encoder\.version",
- r"decoder\.version",
- r"lm_head\.weight",
+ r"encoder.version",
+ r"decoder.version",
+ r"lm_head.weight",
]
def __init__(self, config: PLBartConfig):
diff --git a/src/transformers/models/plbart/tokenization_plbart.py b/src/transformers/models/plbart/tokenization_plbart.py
index 4c302e8b62cead..4a3ee1cdcd11bc 100644
--- a/src/transformers/models/plbart/tokenization_plbart.py
+++ b/src/transformers/models/plbart/tokenization_plbart.py
@@ -33,19 +33,41 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-c-cpp-defect-detection": "https://huggingface.co/uclanlp/plbart-c-cpp-defect-detection/resolve/main/sentencepiece.bpe.model",
+ "uclanlp/plbart-c-cpp-defect-detection": (
+ "https://huggingface.co/uclanlp/plbart-c-cpp-defect-detection/resolve/main/sentencepiece.bpe.model"
+ ),
"uclanlp/plbart-cs-java": "https://huggingface.co/uclanlp/plbart-cs-java/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-en_XX-java": "https://huggingface.co/uclanlp/plbart-en_XX-java/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-go-en_XX": "https://huggingface.co/uclanlp/plbart-go-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-java-clone-detection": "https://huggingface.co/uclanlp/plbart-java-clone-detection/resolve/main/sentencepiece.bpe.model",
+ "uclanlp/plbart-en_XX-java": (
+ "https://huggingface.co/uclanlp/plbart-en_XX-java/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-go-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-go-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-java-clone-detection": (
+ "https://huggingface.co/uclanlp/plbart-java-clone-detection/resolve/main/sentencepiece.bpe.model"
+ ),
"uclanlp/plbart-java-cs": "https://huggingface.co/uclanlp/plbart-java-cs/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-java-en_XX": "https://huggingface.co/uclanlp/plbart-java-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-javascript-en_XX": "https://huggingface.co/uclanlp/plbart-javascript-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-php-en_XX": "https://huggingface.co/uclanlp/plbart-php-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-python-en_XX": "https://huggingface.co/uclanlp/plbart-python-en_XX/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-refine-java-medium": "https://huggingface.co/uclanlp/plbart-refine-java-medium/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-refine-java-small": "https://huggingface.co/uclanlp/plbart-refine-java-small/resolve/main/sentencepiece.bpe.model",
- "uclanlp/plbart-ruby-en_XX": "https://huggingface.co/uclanlp/plbart-ruby-en_XX/resolve/main/sentencepiece.bpe.model",
+ "uclanlp/plbart-java-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-java-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-javascript-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-javascript-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-php-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-php-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-python-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-python-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-refine-java-medium": (
+ "https://huggingface.co/uclanlp/plbart-refine-java-medium/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-refine-java-small": (
+ "https://huggingface.co/uclanlp/plbart-refine-java-small/resolve/main/sentencepiece.bpe.model"
+ ),
+ "uclanlp/plbart-ruby-en_XX": (
+ "https://huggingface.co/uclanlp/plbart-ruby-en_XX/resolve/main/sentencepiece.bpe.model"
+ ),
}
}
diff --git a/src/transformers/models/poolformer/__init__.py b/src/transformers/models/poolformer/__init__.py
index 799752067fdaa5..7cb5e4acacb935 100644
--- a/src/transformers/models/poolformer/__init__.py
+++ b/src/transformers/models/poolformer/__init__.py
@@ -18,17 +18,25 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"],
-}
+_import_structure = {"configuration_poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_poolformer"] = ["PoolFormerFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_poolformer"] = [
"POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"PoolFormerForImageClassification",
@@ -40,10 +48,20 @@
if TYPE_CHECKING:
from .configuration_poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_poolformer import PoolFormerFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_poolformer import (
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
PoolFormerForImageClassification,
diff --git a/src/transformers/models/prophetnet/__init__.py b/src/transformers/models/prophetnet/__init__.py
index be4baf4a16f1c0..b739fb9f5d5a27 100644
--- a/src/transformers/models/prophetnet/__init__.py
+++ b/src/transformers/models/prophetnet/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_prophetnet": ["ProphetNetTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_prophetnet"] = [
"PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"ProphetNetDecoder",
@@ -42,7 +47,12 @@
from .configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
from .tokenization_prophetnet import ProphetNetTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder,
diff --git a/src/transformers/models/prophetnet/configuration_prophetnet.py b/src/transformers/models/prophetnet/configuration_prophetnet.py
index 9a6574c84d2be7..40f5939d99bc7d 100644
--- a/src/transformers/models/prophetnet/configuration_prophetnet.py
+++ b/src/transformers/models/prophetnet/configuration_prophetnet.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" ProphetNet model configuration"""
+from typing import Callable, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -22,7 +23,9 @@
logger = logging.get_logger(__name__)
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/prophetnet-large-uncased": "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json",
+ "microsoft/prophetnet-large-uncased": (
+ "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json"
+ ),
}
@@ -103,32 +106,32 @@ class ProphetNetConfig(PretrainedConfig):
def __init__(
self,
- activation_dropout=0.1,
- activation_function="gelu",
- vocab_size=30522,
- hidden_size=1024,
- encoder_ffn_dim=4096,
- num_encoder_layers=12,
- num_encoder_attention_heads=16,
- decoder_ffn_dim=4096,
- num_decoder_layers=12,
- num_decoder_attention_heads=16,
- attention_dropout=0.1,
- dropout=0.1,
- max_position_embeddings=512,
- init_std=0.02,
- is_encoder_decoder=True,
- add_cross_attention=True,
- decoder_start_token_id=0,
- ngram=2,
- num_buckets=32,
- relative_max_distance=128,
- disable_ngram_loss=False,
- eps=0.0,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
+ activation_dropout: Optional[float] = 0.1,
+ activation_function: Optional[Union[str, Callable]] = "gelu",
+ vocab_size: Optional[int] = 30522,
+ hidden_size: Optional[int] = 1024,
+ encoder_ffn_dim: Optional[int] = 4096,
+ num_encoder_layers: Optional[int] = 12,
+ num_encoder_attention_heads: Optional[int] = 16,
+ decoder_ffn_dim: Optional[int] = 4096,
+ num_decoder_layers: Optional[int] = 12,
+ num_decoder_attention_heads: Optional[int] = 16,
+ attention_dropout: Optional[float] = 0.1,
+ dropout: Optional[float] = 0.1,
+ max_position_embeddings: Optional[int] = 512,
+ init_std: Optional[float] = 0.02,
+ is_encoder_decoder: Optional[bool] = True,
+ add_cross_attention: Optional[bool] = True,
+ decoder_start_token_id: Optional[int] = 0,
+ ngram: Optional[int] = 2,
+ num_buckets: Optional[int] = 32,
+ relative_max_distance: Optional[int] = 128,
+ disable_ngram_loss: Optional[bool] = False,
+ eps: Optional[float] = 0.0,
+ use_cache: Optional[bool] = True,
+ pad_token_id: Optional[int] = 0,
+ bos_token_id: Optional[int] = 1,
+ eos_token_id: Optional[int] = 2,
**kwargs
):
self.vocab_size = vocab_size
@@ -174,5 +177,6 @@ def num_hidden_layers(self) -> int:
@num_hidden_layers.setter
def num_hidden_layers(self, value):
raise NotImplementedError(
- "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and `num_decoder_layers`."
+ "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and"
+ " `num_decoder_layers`."
)
diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index 84fb75f3f33400..1ca6a0e49089a7 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -326,7 +326,8 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput):
@property
def decoder_cross_attentions(self):
warnings.warn(
- "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
+ "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
+ " instead.",
FutureWarning,
)
return self.cross_attentions
@@ -344,7 +345,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
- last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`):
+ last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*):
Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
@@ -411,7 +412,8 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
@property
def decoder_cross_attentions(self):
warnings.warn(
- "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
+ "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
+ " instead.",
FutureWarning,
)
return self.cross_attentions
@@ -562,9 +564,10 @@ def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
- assert (
- decoder_start_token_id is not None
- ), "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the pad_token_id. See ProphetNet docs for more information"
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
+ " pad_token_id. See ProphetNet docs for more information"
+ )
# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
@@ -587,7 +590,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
the forward function.
"""
- def __init__(self, config: ProphetNetConfig):
+ def __init__(self, config: ProphetNetConfig) -> None:
self.max_length = config.max_position_embeddings
super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
@@ -639,9 +642,10 @@ def __init__(
self.num_attn_heads = num_attn_heads
self.head_dim = hidden_size // num_attn_heads
- assert (
- self.head_dim * num_attn_heads == hidden_size
- ), "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and `config.num_decoder_attention_heads`"
+ assert self.head_dim * num_attn_heads == hidden_size, (
+ "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
+ " `config.num_decoder_attention_heads`"
+ )
self.key_proj = nn.Linear(hidden_size, hidden_size)
self.value_proj = nn.Linear(hidden_size, hidden_size)
@@ -708,7 +712,10 @@ def forward(
batch_size * self.num_attn_heads,
tgt_len,
src_len,
- ), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size {attn_weights.shape}"
+ ), (
+ f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size"
+ f" {attn_weights.shape}"
+ )
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if attention_mask is not None and attention_mask.dim() == 0:
@@ -717,7 +724,10 @@ def forward(
self.num_attn_heads * batch_size,
1,
src_len,
- ), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
+ ), (
+ "`attention_mask` should be `None` or of shape attention_mask.size() =="
+ f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
+ )
if attention_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights + attention_mask
@@ -735,9 +745,10 @@ def forward(
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
- assert layer_head_mask.size() == (
- self.num_attn_heads,
- ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+ assert layer_head_mask.size() == (self.num_attn_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
batch_size, self.num_attn_heads, tgt_len, src_len
)
@@ -757,7 +768,10 @@ def forward(
batch_size * self.num_attn_heads,
tgt_len,
self.head_dim,
- ), f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of shape {attn_output.size()}"
+ ), (
+ f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of"
+ f" shape {attn_output.size()}"
+ )
attn_output = (
attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
@@ -847,7 +861,10 @@ def forward(
batch_size,
ngram_sequence_length,
hidden_size,
- ], f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape {hidden_states.shape}"
+ ], (
+ f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
+ f" {hidden_states.shape}"
+ )
# project
query_states = self.query_proj(hidden_states)
@@ -916,9 +933,10 @@ def forward(
).type_as(main_attn_weights)
if layer_head_mask is not None:
- assert layer_head_mask.size() == (
- self.num_attn_heads,
- ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+ assert layer_head_mask.size() == (self.num_attn_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
batch_size, self.num_attn_heads, -1, sequence_length
)
@@ -979,9 +997,10 @@ def forward(
).type_as(predict_attn_weights)
if layer_head_mask is not None:
- assert layer_head_mask.size() == (
- self.num_attn_heads,
- ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+ assert layer_head_mask.size() == (self.num_attn_heads,), (
+ f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
)
@@ -1388,7 +1407,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
embeddings instead of randomly initialized word embeddings.
"""
- def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = None):
+ def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
super().__init__(config)
self.ngram = config.ngram
@@ -1559,9 +1578,10 @@ def forward(
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
- assert attn_mask.size()[0] == (
- len(self.layers)
- ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
# grad cannot be kept because tensor is sliced
@@ -1749,7 +1769,7 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetModel(ProphetNetPreTrainedModel):
- def __init__(self, config):
+ def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -1813,7 +1833,7 @@ def forward(
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
@@ -1935,7 +1955,7 @@ def forward(
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
@@ -2081,11 +2101,12 @@ def get_decoder(self):
@add_start_docstrings(
- "The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal language modeling.",
+ "The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal"
+ " language modeling.",
PROPHETNET_START_DOCSTRING,
)
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
- def __init__(self, config):
+ def __init__(self, config: ProphetNetConfig):
# set config for CLM
config = copy.deepcopy(config)
config.is_decoder = True
@@ -2202,7 +2223,7 @@ def forward(
>>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
>>> labels = tokenizer_dec(
... "us rejects charges against its ambassador in bolivia", return_tensors="pt"
- >>> ).input_ids
+ ... ).input_ids
>>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:])
>>> loss = outputs.loss
@@ -2320,7 +2341,7 @@ class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
classes.
"""
- def __init__(self, config):
+ def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.decoder = ProphetNetDecoder(config)
diff --git a/src/transformers/models/prophetnet/tokenization_prophetnet.py b/src/transformers/models/prophetnet/tokenization_prophetnet.py
index 5bc3951b7969c4..c7725974039043 100644
--- a/src/transformers/models/prophetnet/tokenization_prophetnet.py
+++ b/src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -15,7 +15,7 @@
import collections
import os
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
@@ -28,7 +28,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/prophetnet-large-uncased": "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
+ "microsoft/prophetnet-large-uncased": (
+ "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer"
+ ),
}
}
@@ -109,17 +111,17 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
def __init__(
self,
- vocab_file,
- do_lower_case=True,
- do_basic_tokenize=True,
- never_split=None,
- unk_token="[UNK]",
- sep_token="[SEP]",
- x_sep_token="[X_SEP]",
- pad_token="[PAD]",
- mask_token="[MASK]",
- tokenize_chinese_chars=True,
- strip_accents=None,
+ vocab_file: str,
+ do_lower_case: Optional[bool] = True,
+ do_basic_tokenize: Optional[bool] = True,
+ never_split: Optional[Iterable] = None,
+ unk_token: Optional[str] = "[UNK]",
+ sep_token: Optional[str] = "[SEP]",
+ x_sep_token: Optional[str] = "[X_SEP]",
+ pad_token: Optional[str] = "[PAD]",
+ mask_token: Optional[str] = "[MASK]",
+ tokenize_chinese_chars: Optional[bool] = True,
+ strip_accents: Optional[bool] = None,
**kwargs
):
super().__init__(
@@ -139,8 +141,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
@@ -175,21 +177,24 @@ def _tokenize(self, text):
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
- def _convert_token_to_id(self, token):
+ def _convert_token_to_id(self, token: str):
"""Converts a token (str) in an id using the vocab."""
return self.vocab.get(token, self.vocab.get(self.unk_token))
- def _convert_id_to_token(self, index):
+ def _convert_id_to_token(self, index: int):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
- def convert_tokens_to_string(self, tokens):
+ def convert_tokens_to_string(self, tokens: str):
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
def get_special_tokens_mask(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ self,
+ token_ids_0: List[int],
+ token_ids_1: Optional[List[int]] = None,
+ already_has_special_tokens: Optional[bool] = False,
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
diff --git a/src/transformers/models/qdqbert/__init__.py b/src/transformers/models/qdqbert/__init__.py
index 28fb61c2193c2c..60f03338f48022 100644
--- a/src/transformers/models/qdqbert/__init__.py
+++ b/src/transformers/models/qdqbert/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"],
-}
+_import_structure = {"configuration_qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_qdqbert"] = [
"QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"QDQBertForMaskedLM",
@@ -44,7 +47,12 @@
if TYPE_CHECKING:
from .configuration_qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_qdqbert import (
QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
QDQBertForMaskedLM,
diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py
index ecd9fe73a95d49..0e90dba4fd393d 100755
--- a/src/transformers/models/qdqbert/modeling_qdqbert.py
+++ b/src/transformers/models/qdqbert/modeling_qdqbert.py
@@ -19,6 +19,7 @@
import math
import os
import warnings
+from typing import Optional
import torch
import torch.utils.checkpoint
@@ -61,8 +62,9 @@
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer
except OSError:
logger.error(
- "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. "
- "Please try to reinstall it following the instructions here: https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
+ "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. Please try to reinstall it"
+ " following the instructions here:"
+ " https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
)
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
@@ -173,8 +175,13 @@ def __init__(self, config):
)
def forward(
- self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
- ):
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values_length: int = 0,
+ ) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -501,7 +508,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -952,7 +960,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1325,7 +1333,8 @@ def forward(
if "next_sentence_label" in kwargs:
warnings.warn(
- "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
diff --git a/src/transformers/models/rag/__init__.py b/src/transformers/models/rag/__init__.py
index 00e88f7c0abdf0..7798e8a415745c 100644
--- a/src/transformers/models/rag/__init__.py
+++ b/src/transformers/models/rag/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -27,7 +27,12 @@
"tokenization_rag": ["RagTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
@@ -35,7 +40,12 @@
"RagTokenForGeneration",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
@@ -49,10 +59,20 @@
from .retrieval_rag import RagRetriever
from .tokenization_rag import RagTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 642b13c580c06e..1d6a62b2013d11 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -336,9 +336,10 @@ def from_pretrained_question_encoder_generator(
# by the value of the flag `is_generator` that we need to set correctly.
question_encoder = kwargs_question_encoder.pop("model", None)
if question_encoder is None:
- assert (
- question_encoder_pretrained_model_name_or_path is not None
- ), "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined"
+ assert question_encoder_pretrained_model_name_or_path is not None, (
+ "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
+ " be defined"
+ )
from ..auto.modeling_auto import AutoModel
if "config" not in kwargs_question_encoder:
@@ -357,9 +358,10 @@ def from_pretrained_question_encoder_generator(
generator = kwargs_generator.pop("model", None)
if generator is None:
- assert (
- generator_pretrained_model_name_or_path is not None
- ), "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be defined"
+ assert generator_pretrained_model_name_or_path is not None, (
+ "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
+ " to be defined"
+ )
from ..auto.modeling_auto import AutoModelForSeq2SeqLM
if "config" not in kwargs_generator:
@@ -654,23 +656,27 @@ def forward(
question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
).squeeze(1)
else:
- assert (
- context_input_ids is not None
- ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- context_attention_mask is not None
- ), "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- doc_scores is not None
- ), "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+ assert context_input_ids is not None, (
+ "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
+ " set a retriever using the `set_retriever(...)` function."
+ )
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
assert (
doc_scores is not None
), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
- assert (
- doc_scores.shape[1] % n_docs
- ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+ assert (doc_scores.shape[1] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
# Decoder input without context documents
if decoder_input_ids is not None:
@@ -826,7 +832,7 @@ def forward(
>>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
>>> doc_scores = torch.bmm(
... question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
- >>> ).squeeze(1)
+ ... ).squeeze(1)
>>> # 3. Forward to generator
>>> outputs = model(
... context_input_ids=docs_dict["context_input_ids"],
@@ -1022,12 +1028,14 @@ def generate(
new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
else: # input_ids is None, need context_input_ids/mask and doc_scores
- assert (
- context_attention_mask is not None
- ), "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- doc_scores is not None
- ), "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
individual_input_ids = generator_input_ids.repeat(
num_candidates, 1
@@ -1293,7 +1301,7 @@ def forward(
>>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
>>> doc_scores = torch.bmm(
... question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
- >>> ).squeeze(1)
+ ... ).squeeze(1)
>>> # 3. Forward to generator
>>> outputs = model(
... context_input_ids=docs_dict["context_input_ids"],
@@ -1567,9 +1575,10 @@ def generate(
1
)
- assert (
- context_input_ids.shape[0] % n_docs
- ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+ assert (context_input_ids.shape[0] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
# batch_size
batch_size = context_input_ids.shape[0] // n_docs
diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py
index 30f50a29ff404d..3d0ad31db8add8 100644
--- a/src/transformers/models/rag/modeling_tf_rag.py
+++ b/src/transformers/models/rag/modeling_tf_rag.py
@@ -321,9 +321,10 @@ def from_pretrained_question_encoder_generator(
# by the value of the flag `is_generator` that we need to set correctly.
question_encoder = kwargs_question_encoder.pop("model", None)
if question_encoder is None:
- assert (
- question_encoder_pretrained_model_name_or_path is not None
- ), "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined"
+ assert question_encoder_pretrained_model_name_or_path is not None, (
+ "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
+ " be defined"
+ )
from ..auto.modeling_tf_auto import TFAutoModel
@@ -343,9 +344,10 @@ def from_pretrained_question_encoder_generator(
generator = kwargs_generator.pop("generator", None)
if generator is None:
- assert (
- generator_pretrained_model_name_or_path is not None
- ), "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be defined"
+ assert generator_pretrained_model_name_or_path is not None, (
+ "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
+ " to be defined"
+ )
from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM
@@ -632,23 +634,27 @@ def call(
)
else:
- assert (
- context_input_ids is not None
- ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- context_attention_mask is not None
- ), "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- doc_scores is not None
- ), "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+ assert context_input_ids is not None, (
+ "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
+ " set a retriever using the `set_retriever(...)` function."
+ )
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
assert (
doc_scores is not None
), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
- assert (
- doc_scores.shape[1] % n_docs
- ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+ assert (doc_scores.shape[1] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
# Decoder input without context documents
if decoder_input_ids is not None:
@@ -1149,9 +1155,10 @@ def generate(
)
doc_scores = tf.squeeze(doc_scores, axis=1)
- assert (
- context_input_ids.shape[0] % n_docs
- ) == 0, f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is {context_input_ids.shape[0]}."
+ assert (context_input_ids.shape[0] % n_docs) == 0, (
+ f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
+ f" {context_input_ids.shape[0]}."
+ )
batch_size = context_input_ids.shape[0] // n_docs
@@ -1286,9 +1293,10 @@ def shift_tokens_right(self, input_ids, start_token_id=None):
if start_token_id is None:
start_token_id = self.generator.config.decoder_start_token_id
- assert (
- start_token_id is not None
- ), "self.generator.config.decoder_start_token_id has to be defined. In Rag we commonly use Bart as generator, see Bart docs for more information"
+ assert start_token_id is not None, (
+ "self.generator.config.decoder_start_token_id has to be defined. In Rag we commonly use Bart as"
+ " generator, see Bart docs for more information"
+ )
pad_token_id = self.generator.config.pad_token_id
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
@@ -1745,12 +1753,14 @@ def generate(
new_input_ids = tf.tile(input_ids[index : index + 1], (num_candidates, 1))
outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
else: # input_ids is None, need context_input_ids/mask and doc_scores
- assert (
- context_attention_mask is not None
- ), "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
- assert (
- doc_scores is not None
- ), "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
+ assert context_attention_mask is not None, (
+ "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
+ " can set a retriever using the `set_retriever(...)` function."
+ )
+ assert doc_scores is not None, (
+ "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
+ " retriever using the `set_retriever(...)` function."
+ )
individual_input_ids = tf.tile(
generator_input_ids, (num_candidates, 1)
diff --git a/src/transformers/models/rag/retrieval_rag.py b/src/transformers/models/rag/retrieval_rag.py
index f39fc48d27c8b3..7a3c5635f24f9b 100644
--- a/src/transformers/models/rag/retrieval_rag.py
+++ b/src/transformers/models/rag/retrieval_rag.py
@@ -354,7 +354,7 @@ class RagRetriever:
>>> dataset = (
... ...
- >>> ) # dataset must be a datasets.Datasets object with columns "title", "text" and "embeddings", and it must have a faiss index
+ ... ) # dataset must be a datasets.Datasets object with columns "title", "text" and "embeddings", and it must have a faiss index
>>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", indexed_dataset=dataset)
>>> # To load your own indexed dataset built with the datasets library that was saved on disk. More info in examples/rag/use_own_knowledge_dataset.py
diff --git a/src/transformers/models/realm/__init__.py b/src/transformers/models/realm/__init__.py
index db113dbd5b29a1..2464c0ae27d965 100644
--- a/src/transformers/models/realm/__init__.py
+++ b/src/transformers/models/realm/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -25,10 +25,20 @@
"tokenization_realm": ["RealmTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_realm_fast"] = ["RealmTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_realm"] = [
"REALM_PRETRAINED_MODEL_ARCHIVE_LIST",
"RealmEmbedder",
@@ -46,10 +56,20 @@
from .configuration_realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig
from .tokenization_realm import RealmTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_realm import RealmTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_realm import (
REALM_PRETRAINED_MODEL_ARCHIVE_LIST,
RealmEmbedder,
diff --git a/src/transformers/models/realm/configuration_realm.py b/src/transformers/models/realm/configuration_realm.py
index d3383bd897c5d7..8d816a736e7a33 100644
--- a/src/transformers/models/realm/configuration_realm.py
+++ b/src/transformers/models/realm/configuration_realm.py
@@ -21,10 +21,18 @@
logger = logging.get_logger(__name__)
REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json",
- "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json",
- "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json",
- "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json",
+ "google/realm-cc-news-pretrained-embedder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json"
+ ),
+ "google/realm-cc-news-pretrained-encoder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json"
+ ),
+ "google/realm-cc-news-pretrained-scorer": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json"
+ ),
+ "google/realm-cc-news-pretrained-openqa": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json"
+ ),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json",
diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py
index eec4fb2b7debcc..e6de31a4cb5eda 100644
--- a/src/transformers/models/realm/modeling_realm.py
+++ b/src/transformers/models/realm/modeling_realm.py
@@ -189,8 +189,13 @@ def __init__(self, config):
)
def forward(
- self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
- ):
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values_length: int = 0,
+ ) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -253,7 +258,7 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@@ -497,7 +502,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -1078,7 +1084,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1361,7 +1367,8 @@ def forward(
@add_start_docstrings(
- "The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood loss.",
+ "The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood"
+ " loss.",
REALM_START_DOCSTRING,
)
class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
@@ -1782,7 +1789,7 @@ def forward(
... add_special_tokens=False,
... return_token_type_ids=False,
... return_attention_mask=False,
- >>> ).input_ids
+ ... ).input_ids
>>> reader_output, predicted_answer_ids = model(**question_ids, answer_ids=answer_ids, return_dict=False)
>>> predicted_answer = tokenizer.decode(predicted_answer_ids)
diff --git a/src/transformers/models/realm/tokenization_realm.py b/src/transformers/models/realm/tokenization_realm.py
index 426b5d775cf931..63295826d462b6 100644
--- a/src/transformers/models/realm/tokenization_realm.py
+++ b/src/transformers/models/realm/tokenization_realm.py
@@ -30,10 +30,18 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
+ "google/realm-cc-news-pretrained-embedder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-encoder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-scorer": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-openqa": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt"
+ ),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
@@ -165,8 +173,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
diff --git a/src/transformers/models/realm/tokenization_realm_fast.py b/src/transformers/models/realm/tokenization_realm_fast.py
index 87580baa228b62..f61fa8418ed2ba 100644
--- a/src/transformers/models/realm/tokenization_realm_fast.py
+++ b/src/transformers/models/realm/tokenization_realm_fast.py
@@ -31,24 +31,48 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
- "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
+ "google/realm-cc-news-pretrained-embedder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-encoder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-scorer": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt"
+ ),
+ "google/realm-cc-news-pretrained-openqa": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt"
+ ),
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
},
"tokenizer_file": {
- "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont",
- "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
- "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
- "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json",
- "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
- "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json",
- "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
- "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json",
+ "google/realm-cc-news-pretrained-embedder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont"
+ ),
+ "google/realm-cc-news-pretrained-encoder": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json"
+ ),
+ "google/realm-cc-news-pretrained-scorer": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json"
+ ),
+ "google/realm-cc-news-pretrained-openqa": (
+ "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json"
+ ),
+ "google/realm-orqa-nq-openqa": (
+ "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json"
+ ),
+ "google/realm-orqa-nq-reader": (
+ "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json"
+ ),
+ "google/realm-orqa-wq-openqa": (
+ "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json"
+ ),
+ "google/realm-orqa-wq-reader": (
+ "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/reformer/__init__.py b/src/transformers/models/reformer/__init__.py
index 3c6130301b53e8..979074bcc728b6 100644
--- a/src/transformers/models/reformer/__init__.py
+++ b/src/transformers/models/reformer/__init__.py
@@ -18,20 +18,39 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
-}
+_import_structure = {"configuration_reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_reformer"] = ["ReformerTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_reformer_fast"] = ["ReformerTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_reformer"] = [
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"ReformerAttention",
@@ -48,13 +67,28 @@
if TYPE_CHECKING:
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_reformer import ReformerTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_reformer_fast import ReformerTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,
diff --git a/src/transformers/models/reformer/configuration_reformer.py b/src/transformers/models/reformer/configuration_reformer.py
index d481b3b1376846..ea2a1abd08252e 100755
--- a/src/transformers/models/reformer/configuration_reformer.py
+++ b/src/transformers/models/reformer/configuration_reformer.py
@@ -22,7 +22,9 @@
logger = logging.get_logger(__name__)
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/config.json",
+ "google/reformer-crime-and-punishment": (
+ "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/config.json"
+ ),
"google/reformer-enwik8": "https://huggingface.co/google/reformer-enwik8/resolve/main/config.json",
}
diff --git a/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
index 2e2e3f3a60dd93..f25e166ef917cb 100755
--- a/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
+++ b/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
@@ -210,8 +210,10 @@ def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained Reformer model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained Reformer model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py
index 089481f8542ee1..8430f3a62c0dc5 100755
--- a/src/transformers/models/reformer/modeling_reformer.py
+++ b/src/transformers/models/reformer/modeling_reformer.py
@@ -380,9 +380,10 @@ def forward(
# check if cache shall be used and that hidden states are already cached
if do_cached_attention:
- assert (
- sequence_length == 1
- ), f"At the moment, auto-regressive language generation is only possible one word at a time. Make sure that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed."
+ assert sequence_length == 1, (
+ "At the moment, auto-regressive language generation is only possible one word at a time. Make sure"
+ f" that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed."
+ )
past_buckets = past_buckets_states[0]
past_states = past_buckets_states[1]
@@ -505,9 +506,10 @@ def forward(
)
if self.chunk_length is None:
- assert (
- self.num_chunks_before == 0 and self.num_chunks_after == 0
- ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
+ assert self.num_chunks_before == 0 and self.num_chunks_after == 0, (
+ "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and"
+ " `config.num_chunks_before` are set to 0."
+ )
elif do_cached_attention and past_buckets is not None:
# use max sequence length
sorted_bucket_idx_per_hash = sorted_bucket_idx
@@ -577,7 +579,10 @@ def forward(
self.num_attention_heads,
sequence_length,
self.attention_head_size,
- ), "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length, config.attention_head_size]`."
+ ), (
+ "out_vectors have be of shape `[batch_size, config.num_attention_heads, sequence_length,"
+ " config.attention_head_size]`."
+ )
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
@@ -891,7 +896,10 @@ def _get_relevant_hid_states_and_buckets(
self.num_attention_heads,
num_hashes,
sequence_length,
- ), f"bucket_idx should have shape {(batch_size, self.num_attention_heads, num_hashes, sequence_length)}, but has shape {bucket_idx.shape}."
+ ), (
+ f"bucket_idx should have shape {(batch_size, self.num_attention_heads, num_hashes, sequence_length)}, but"
+ f" has shape {bucket_idx.shape}."
+ )
# find indices of new bucket indices
relevant_bucket_idx = (bucket_idx == (bucket_idx.shape[-1] - 1)).nonzero()
@@ -925,12 +933,20 @@ def _get_relevant_hid_states_and_buckets(
assert (
relevant_hidden_states.shape[2]
== (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes
- ), f"There should be {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes} `hidden_states`, there are {relevant_hidden_states.shape[2]} `hidden_states`."
+ ), (
+ "There should be"
+ f" {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes} `hidden_states`,"
+ f" there are {relevant_hidden_states.shape[2]} `hidden_states`."
+ )
assert (
relevant_bucket_idx_chunk.shape[-1]
== (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length
- ), f"There should be {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length} `hidden_states`, there are {relevant_bucket_idx_chunk.shape[-1]} `bucket_idx`."
+ ), (
+ "There should be"
+ f" {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length} `hidden_states`, there are"
+ f" {relevant_bucket_idx_chunk.shape[-1]} `bucket_idx`."
+ )
return relevant_hidden_states, relevant_bucket_idx_chunk, query_buckets
@@ -1054,9 +1070,10 @@ def forward(
# check if cache shall be used and that hidden states are already cached
if use_cache and past_buckets_states[1] is not None:
- assert (
- past_buckets_states[0] is None
- ), "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching hidden_states_and_buckets."
+ assert past_buckets_states[0] is None, (
+ "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching"
+ " hidden_states_and_buckets."
+ )
key_value_hidden_states = self._retrieve_relevant_hidden_states(
past_buckets_states[1], self.chunk_length, self.num_chunks_before
)
@@ -1092,9 +1109,10 @@ def forward(
), f"last dim of query_key_vectors is {value_vectors.shape[-1]} but should be {self.attention_head_size}."
if self.chunk_length is None:
- assert (
- self.num_chunks_before == 0 and self.num_chunks_after == 0
- ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0."
+ assert self.num_chunks_before == 0 and self.num_chunks_after == 0, (
+ "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and"
+ " `config.num_chunks_before` are set to 0."
+ )
# normalize key vectors
key_vectors = key_vectors / torch.sqrt(
@@ -1514,9 +1532,10 @@ def backward_pass(
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
- assert (
- self.training
- ), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode."
+ assert self.training, (
+ "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the"
+ " model into training mode."
+ )
with torch.enable_grad():
next_attn_output.requires_grad = True
@@ -1957,7 +1976,7 @@ class ReformerModelWithLMHeadOutput(ModelOutput):
@add_start_docstrings(
- "The bare Reformer Model transformer outputting raw hidden-states" "without any specific head on top.",
+ "The bare Reformer Model transformer outputting raw hidden-stateswithout any specific head on top.",
REFORMER_START_DOCSTRING,
)
class ReformerModel(ReformerPreTrainedModel):
@@ -2176,12 +2195,14 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
- assert (
- "local" not in self.config.attn_layers or config.local_num_chunks_after == 0
- ), f"If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not {config.local_num_chunks_after}."
- assert (
- "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0
- ), f"If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 1 and not {config.lsh_num_chunks_after}."
+ assert "local" not in self.config.attn_layers or config.local_num_chunks_after == 0, (
+ "If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not"
+ f" {config.local_num_chunks_after}."
+ )
+ assert "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0, (
+ "If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 1 and not"
+ f" {config.lsh_num_chunks_after}."
+ )
self.reformer = ReformerModel(config)
self.lm_head = ReformerOnlyLMHead(config)
@@ -2296,9 +2317,10 @@ def _reorder_cache(self, past, beam_idx):
class ReformerForMaskedLM(ReformerPreTrainedModel):
def __init__(self, config):
super().__init__(config)
- assert (
- not config.is_decoder
- ), "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
+ assert not config.is_decoder, (
+ "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional"
+ " self-attention."
+ )
self.reformer = ReformerModel(config)
self.lm_head = ReformerOnlyLMHead(config)
diff --git a/src/transformers/models/reformer/tokenization_reformer.py b/src/transformers/models/reformer/tokenization_reformer.py
index 8c75dda15e705a..d5d73f3e451f9f 100644
--- a/src/transformers/models/reformer/tokenization_reformer.py
+++ b/src/transformers/models/reformer/tokenization_reformer.py
@@ -34,7 +34,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
+ "google/reformer-crime-and-punishment": (
+ "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
+ )
}
}
diff --git a/src/transformers/models/reformer/tokenization_reformer_fast.py b/src/transformers/models/reformer/tokenization_reformer_fast.py
index e6a84837915983..e9c6a61993d09a 100644
--- a/src/transformers/models/reformer/tokenization_reformer_fast.py
+++ b/src/transformers/models/reformer/tokenization_reformer_fast.py
@@ -38,10 +38,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
+ "google/reformer-crime-and-punishment": (
+ "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
+ )
},
"tokenizer_file": {
- "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/tokenizer.json"
+ "google/reformer-crime-and-punishment": (
+ "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/tokenizer.json"
+ )
},
}
diff --git a/src/transformers/models/regnet/__init__.py b/src/transformers/models/regnet/__init__.py
index 185ead37b640e3..2de85e0cc19b18 100644
--- a/src/transformers/models/regnet/__init__.py
+++ b/src/transformers/models/regnet/__init__.py
@@ -19,13 +19,17 @@
# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable
-_import_structure = {
- "configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"],
-}
+_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_regnet"] = [
"REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"RegNetForImageClassification",
@@ -37,7 +41,12 @@
if TYPE_CHECKING:
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_regnet import (
REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
RegNetForImageClassification,
diff --git a/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py b/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
index 8024ef67920114..a43967d0095d2b 100644
--- a/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
+++ b/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
@@ -277,7 +277,10 @@ def load_using_classy_vision(checkpoint_url: str) -> Tuple[Dict, Dict]:
"--model_name",
default=None,
type=str,
- help="The name of the model you wish to convert, it must be one of the supported regnet* architecture, currently: regnetx-*, regnety-*. If `None`, all of them will the converted.",
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
+ " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
diff --git a/src/transformers/models/regnet/convert_regnet_to_pytorch.py b/src/transformers/models/regnet/convert_regnet_to_pytorch.py
index 96e4ab700ab5e2..9bb0ba0f053283 100644
--- a/src/transformers/models/regnet/convert_regnet_to_pytorch.py
+++ b/src/transformers/models/regnet/convert_regnet_to_pytorch.py
@@ -84,7 +84,8 @@ def __call__(self, x: Tensor):
if len(dest_traced) != len(src_traced) and self.raise_if_mismatch:
raise Exception(
- f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}."
+ f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
+ f" destination module has {len(dest_traced)}."
)
for dest_m, src_m in zip(dest_traced, src_traced):
@@ -431,7 +432,10 @@ def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Mo
"--model_name",
default=None,
type=str,
- help="The name of the model you wish to convert, it must be one of the supported regnet* architecture, currently: regnetx-*, regnety-*. If `None`, all of them will the converted.",
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
+ " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py
index 0ebd05a25ce15c..8d8098caf1ea14 100644
--- a/src/transformers/models/regnet/modeling_regnet.py
+++ b/src/transformers/models/regnet/modeling_regnet.py
@@ -100,7 +100,7 @@ def forward(self, hidden_state):
# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut with ResNet->RegNet
-class RegNetShortCut(nn.Sequential):
+class RegNetShortCut(nn.Module):
"""
RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
@@ -111,6 +111,11 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.normalization = nn.BatchNorm2d(out_channels)
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ return hidden_state
+
class RegNetSELayer(nn.Module):
"""
diff --git a/src/transformers/models/rembert/__init__.py b/src/transformers/models/rembert/__init__.py
index fb5defeee5d074..10af6c4d27f3be 100644
--- a/src/transformers/models/rembert/__init__.py
+++ b/src/transformers/models/rembert/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tf_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"],
-}
+_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_rembert"] = ["RemBertTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_rembert_fast"] = ["RemBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_rembert"] = [
"REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"RemBertForCausalLM",
@@ -53,7 +67,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_rembert"] = [
"TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRemBertForCausalLM",
@@ -71,13 +90,28 @@
if TYPE_CHECKING:
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_rembert import RemBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_rembert_fast import RemBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_rembert import (
REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
RemBertForCausalLM,
@@ -92,7 +126,12 @@
load_tf_weights_in_rembert,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_rembert import (
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRemBertForCausalLM,
diff --git a/src/transformers/models/rembert/configuration_rembert.py b/src/transformers/models/rembert/configuration_rembert.py
index 589c40bdcb98d6..732d75c5cc2b3d 100644
--- a/src/transformers/models/rembert/configuration_rembert.py
+++ b/src/transformers/models/rembert/configuration_rembert.py
@@ -21,7 +21,7 @@
logger = logging.get_logger(__name__)
REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "rembert": "https://huggingface.co/google/rembert/resolve/main/config.json",
+ "google/rembert": "https://huggingface.co/google/rembert/resolve/main/config.json",
# See all RemBERT models at https://huggingface.co/models?filter=rembert
}
@@ -80,16 +80,17 @@ class RemBertConfig(PretrainedConfig):
Example:
```python
+ >>> from transformers import RemBertModel, RemBertConfig
- ```
+ >>> # Initializing a RemBERT rembert style configuration
+ >>> configuration = RemBertConfig()
- >>> from transformers import RemBertModel, RemBertConfig >>> # Initializing a RemBERT rembert style
- configuration >>> configuration = RemBertConfig()
+ >>> # Initializing a model from the rembert style configuration
+ >>> model = RemBertModel(configuration)
- >>> # Initializing a model from the rembert style configuration >>> model = RemBertModel(configuration)
-
- >>> # Accessing the model configuration >>> configuration = model.config
- """
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
model_type = "rembert"
def __init__(
diff --git a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py
index 2a3c497d37a895..4c3d53e789de01 100755
--- a/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py
@@ -51,8 +51,10 @@ def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_fil
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained RemBERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained RemBERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py
index dc6f88f886ad69..b6c20cb689d86f 100755
--- a/src/transformers/models/rembert/modeling_rembert.py
+++ b/src/transformers/models/rembert/modeling_rembert.py
@@ -460,7 +460,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -785,7 +786,7 @@ class PreTrainedModel
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@@ -857,7 +858,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -938,7 +939,7 @@ def set_output_embeddings(self, new_embeddings):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1183,7 +1184,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1280,7 +1281,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1373,7 +1374,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1452,7 +1453,7 @@ def __init__(self, config):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
diff --git a/src/transformers/models/rembert/modeling_tf_rembert.py b/src/transformers/models/rembert/modeling_tf_rembert.py
index c039f263503792..2e25dafed48302 100644
--- a/src/transformers/models/rembert/modeling_tf_rembert.py
+++ b/src/transformers/models/rembert/modeling_tf_rembert.py
@@ -414,8 +414,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -938,7 +938,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@@ -1041,7 +1041,7 @@ def get_lm_head(self) -> tf.keras.layers.Layer:
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1131,7 +1131,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
@unpack_inputs
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFCausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@@ -1262,7 +1262,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1352,7 +1352,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1471,7 +1471,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@@ -1550,7 +1550,7 @@ def __init__(self, config: RemBertConfig, *inputs, **kwargs):
@add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="rembert",
+ checkpoint="google/rembert",
output_type=TFQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
diff --git a/src/transformers/models/resnet/__init__.py b/src/transformers/models/resnet/__init__.py
index 8a839228f87243..e1c0a9ec84d603 100644
--- a/src/transformers/models/resnet/__init__.py
+++ b/src/transformers/models/resnet/__init__.py
@@ -18,14 +18,19 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
- "configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"],
+ "configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig", "ResNetOnnxConfig"]
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_resnet"] = [
"RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"ResNetForImageClassification",
@@ -35,9 +40,14 @@
if TYPE_CHECKING:
- from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
-
- if is_torch_available():
+ from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetForImageClassification,
diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py
index 8e5f6e656d1fba..9bfc694bb1442a 100644
--- a/src/transformers/models/resnet/configuration_resnet.py
+++ b/src/transformers/models/resnet/configuration_resnet.py
@@ -14,7 +14,13 @@
# limitations under the License.
""" ResNet model configuration"""
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
@@ -89,3 +95,20 @@ def __init__(
self.layer_type = layer_type
self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage
+
+
+class ResNetOnnxConfig(OnnxConfig):
+
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "sequence"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-3
diff --git a/src/transformers/models/resnet/convert_resnet_to_pytorch.py b/src/transformers/models/resnet/convert_resnet_to_pytorch.py
index 60973ecdec0623..55a865ed593620 100644
--- a/src/transformers/models/resnet/convert_resnet_to_pytorch.py
+++ b/src/transformers/models/resnet/convert_resnet_to_pytorch.py
@@ -81,7 +81,8 @@ def __call__(self, x: Tensor):
if len(dest_traced) != len(src_traced):
raise Exception(
- f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}."
+ f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
+ f" destination module has {len(dest_traced)}."
)
for dest_m, src_m in zip(dest_traced, src_traced):
@@ -173,7 +174,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
"--model_name",
default=None,
type=str,
- help="The name of the model you wish to convert, it must be one of the supported resnet* architecture, currently: resnet18,26,34,50,101,152. If `None`, all of them will the converted.",
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported resnet* architecture,"
+ " currently: resnet18,26,34,50,101,152. If `None`, all of them will the converted."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py
index f2f555d7f5196c..61ed3c98871589 100644
--- a/src/transformers/models/resnet/modeling_resnet.py
+++ b/src/transformers/models/resnet/modeling_resnet.py
@@ -52,7 +52,7 @@
]
-class ResNetConvLayer(nn.Sequential):
+class ResNetConvLayer(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
):
@@ -63,8 +63,14 @@ def __init__(
self.normalization = nn.BatchNorm2d(out_channels)
self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ return hidden_state
-class ResNetEmbeddings(nn.Sequential):
+
+class ResNetEmbeddings(nn.Module):
"""
ResNet Embeddings (stem) composed of a single aggressive convolution.
"""
@@ -76,8 +82,13 @@ def __init__(self, config: ResNetConfig):
)
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ def forward(self, input: Tensor) -> Tensor:
+ embedding = self.embedder(input)
+ embedding = self.pooler(embedding)
+ return embedding
+
-class ResNetShortCut(nn.Sequential):
+class ResNetShortCut(nn.Module):
"""
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
@@ -88,6 +99,11 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.normalization = nn.BatchNorm2d(out_channels)
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ return hidden_state
+
class ResNetBasicLayer(nn.Module):
"""
@@ -148,7 +164,7 @@ def forward(self, hidden_state):
return hidden_state
-class ResNetStage(nn.Sequential):
+class ResNetStage(nn.Module):
"""
A ResNet stage composed by stacked layers.
"""
@@ -171,6 +187,12 @@ def __init__(
*[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)],
)
+ def forward(self, input: Tensor) -> Tensor:
+ hidden_state = input
+ for layer in self.layers:
+ hidden_state = layer(hidden_state)
+ return hidden_state
+
class ResNetEncoder(nn.Module):
def __init__(self, config: ResNetConfig):
diff --git a/src/transformers/models/retribert/__init__.py b/src/transformers/models/retribert/__init__.py
index e4d383780b667d..34cfadfe1a8743 100644
--- a/src/transformers/models/retribert/__init__.py
+++ b/src/transformers/models/retribert/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -26,10 +26,20 @@
"tokenization_retribert": ["RetriBertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_retribert_fast"] = ["RetriBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_retribert"] = [
"RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"RetriBertModel",
@@ -41,10 +51,20 @@
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .tokenization_retribert import RetriBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_retribert_fast import RetriBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_retribert import (
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
RetriBertModel,
diff --git a/src/transformers/models/retribert/configuration_retribert.py b/src/transformers/models/retribert/configuration_retribert.py
index 1e4feb2a69090c..23172cf40ec7d3 100644
--- a/src/transformers/models/retribert/configuration_retribert.py
+++ b/src/transformers/models/retribert/configuration_retribert.py
@@ -22,7 +22,9 @@
# TODO: upload to AWS
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/config.json",
+ "yjernite/retribert-base-uncased": (
+ "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/retribert/modeling_retribert.py b/src/transformers/models/retribert/modeling_retribert.py
index 8470aea7ae59d2..5a12c962e29230 100644
--- a/src/transformers/models/retribert/modeling_retribert.py
+++ b/src/transformers/models/retribert/modeling_retribert.py
@@ -18,6 +18,7 @@
import math
+from typing import Optional
import torch
import torch.utils.checkpoint as checkpoint
@@ -85,7 +86,7 @@ def _init_weights(self, module):
RETRIBERT_START_DOCSTRING,
)
class RetriBertModel(RetriBertPreTrainedModel):
- def __init__(self, config):
+ def __init__(self, config: RetriBertConfig) -> None:
super().__init__(config)
self.projection_dim = config.projection_dim
@@ -117,7 +118,7 @@ def embed_sentences_checkpointed(
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
- attention_mask, input_shape, device
+ attention_mask, input_shape
)
# define function for checkpointing
@@ -173,8 +174,13 @@ def embed_answers(
return self.project_doc(a_reps)
def forward(
- self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1
- ):
+ self,
+ input_ids_query: torch.LongTensor,
+ attention_mask_query: Optional[torch.FloatTensor],
+ input_ids_doc: torch.LongTensor,
+ attention_mask_doc: Optional[torch.FloatTensor],
+ checkpoint_batch_size: int = -1,
+ ) -> torch.FloatTensor:
r"""
Args:
input_ids_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
diff --git a/src/transformers/models/retribert/tokenization_retribert.py b/src/transformers/models/retribert/tokenization_retribert.py
index 934054e6050fce..b61c0634406a54 100644
--- a/src/transformers/models/retribert/tokenization_retribert.py
+++ b/src/transformers/models/retribert/tokenization_retribert.py
@@ -24,7 +24,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt",
+ "yjernite/retribert-base-uncased": (
+ "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt"
+ ),
}
}
@@ -42,7 +44,7 @@ class RetriBertTokenizer(BertTokenizer):
r"""
Constructs a RetriBERT tokenizer.
- [`RetroBertTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting
+ [`RetriBertTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting
and wordpiece.
Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
diff --git a/src/transformers/models/retribert/tokenization_retribert_fast.py b/src/transformers/models/retribert/tokenization_retribert_fast.py
index 43cc3837214b3c..3451d1224a7a18 100644
--- a/src/transformers/models/retribert/tokenization_retribert_fast.py
+++ b/src/transformers/models/retribert/tokenization_retribert_fast.py
@@ -25,10 +25,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt",
+ "yjernite/retribert-base-uncased": (
+ "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "yjernite/retribert-base-uncased": "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/tokenizer.json",
+ "yjernite/retribert-base-uncased": (
+ "https://huggingface.co/yjernite/retribert-base-uncased/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/roberta/__init__.py b/src/transformers/models/roberta/__init__.py
index 0e1070dbe7e540..2429ba113e8a19 100644
--- a/src/transformers/models/roberta/__init__.py
+++ b/src/transformers/models/roberta/__init__.py
@@ -18,7 +18,14 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -26,10 +33,20 @@
"tokenization_roberta": ["RobertaTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_roberta_fast"] = ["RobertaTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_roberta"] = [
"ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"RobertaForCausalLM",
@@ -42,7 +59,12 @@
"RobertaPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_roberta"] = [
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRobertaForCausalLM",
@@ -56,8 +78,14 @@
"TFRobertaPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_roberta"] = [
+ "FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
@@ -72,10 +100,20 @@
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaOnnxConfig
from .tokenization_roberta import RobertaTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_roberta_fast import RobertaTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
RobertaForCausalLM,
@@ -88,7 +126,12 @@
RobertaPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForCausalLM,
@@ -102,8 +145,14 @@
TFRobertaPreTrainedModel,
)
- if is_flax_available():
- from .modeling_tf_roberta import (
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_roberta import (
+ FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py
index 7f195bc70867c7..84bf15da6d8614 100644
--- a/src/transformers/models/roberta/modeling_flax_roberta.py
+++ b/src/transformers/models/roberta/modeling_flax_roberta.py
@@ -20,14 +20,16 @@
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
-from jax.random import PRNGKey
from ...modeling_flax_outputs import (
- FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxBaseModelOutputWithPooling,
+ FlaxBaseModelOutputWithPoolingAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
@@ -174,13 +176,15 @@ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, dete
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
class FlaxRobertaSelfAttention(nn.Module):
config: RobertaConfig
+ causal: bool = False
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
+ self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
- "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
- : {self.config.num_attention_heads}"
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
+ " : {self.config.num_attention_heads}"
)
self.query = nn.Dense(
@@ -199,30 +203,113 @@ def setup(self):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
+
+ @nn.compact
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slighly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
+ key_value_states: Optional[jnp.array] = None,
+ init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
):
- head_dim = self.config.hidden_size // self.config.num_attention_heads
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ # get query proj
+ query_states = self.query(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.key(key_value_states)
+ value_states = self.value(key_value_states)
+ else:
+ # self_attention
+ key_states = self.key(hidden_states)
+ value_states = self.value(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
- query_states = self.query(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
- value_states = self.value(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
- key_states = self.key(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
- attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
@@ -282,10 +369,11 @@ def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
config: RobertaConfig
+ causal: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
- self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
+ self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
def __call__(
@@ -293,6 +381,8 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
+ key_value_states=None,
+ init_cache=False,
deterministic=True,
output_attentions: bool = False,
):
@@ -303,6 +393,8 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
+ key_value_states=key_value_states,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
@@ -363,27 +455,46 @@ class FlaxRobertaLayer(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
- self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype)
+ self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
+ if self.config.add_cross_attention:
+ self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
):
+ # Self Attention
attention_outputs = self.attention(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
+ # Cross-Attention Block
+ if encoder_hidden_states is not None:
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=layer_head_mask,
+ key_value_states=encoder_hidden_states,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
@@ -391,6 +502,8 @@ def __call__(
if output_attentions:
outputs += (attention_outputs[1],)
+ if encoder_hidden_states is not None:
+ outputs += (cross_attention_outputs[1],)
return outputs
@@ -409,6 +522,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -416,13 +532,14 @@ def __call__(
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# Check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
- {head_mask.shape[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
+ f" {head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
@@ -433,6 +550,9 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
@@ -442,6 +562,9 @@ def __call__(
if output_attentions:
all_attentions += (layer_outputs[1],)
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -450,8 +573,11 @@ def __call__(
if not return_dict:
return tuple(v for v in outputs if v is not None)
- return FlaxBaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
)
@@ -468,6 +594,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -477,6 +606,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -603,9 +735,26 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
- random_params = self.module.init(
- rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
- )["params"]
+ if self.config.add_cross_attention:
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
+ encoder_attention_mask = attention_mask
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ return_dict=False,
+ )
+ else:
+ module_init_outputs = self.module.init(
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
+ )
+
+ random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
@@ -617,6 +766,26 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
else:
return random_params
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
@@ -625,12 +794,15 @@ def __call__(
token_type_ids=None,
position_ids=None,
head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
params: dict = None,
- dropout_rng: PRNGKey = None,
+ dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ past_key_values: dict = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -656,19 +828,60 @@ def __call__(
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
- return self.module.apply(
- {"params": params or self.params},
- jnp.array(input_ids, dtype="i4"),
- jnp.array(attention_mask, dtype="i4"),
- jnp.array(token_type_ids, dtype="i4"),
- jnp.array(position_ids, dtype="i4"),
- jnp.array(head_mask, dtype="i4"),
- not train,
- output_attentions,
- output_hidden_states,
- return_dict,
- rngs=rngs,
- )
+ inputs = {"params": params or self.params}
+
+ if self.config.add_cross_attention:
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
+ # changed by FlaxRobertaAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ head_mask=jnp.array(head_mask, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ deterministic=not train,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ else:
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ head_mask=jnp.array(head_mask, dtype="i4"),
+ deterministic=not train,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ rngs=rngs,
+ )
+
+ return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
@@ -686,9 +899,12 @@ def __call__(
self,
input_ids,
attention_mask,
- token_type_ids: Optional[np.ndarray] = None,
- position_ids: Optional[np.ndarray] = None,
- head_mask: Optional[np.ndarray] = None,
+ token_type_ids: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ head_mask: Optional[jnp.ndarray] = None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -710,6 +926,9 @@ def __call__(
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -723,11 +942,12 @@ def __call__(
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
- return FlaxBaseModelOutputWithPooling(
+ return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
)
@@ -1101,3 +1321,108 @@ class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
+
+
+class FlaxRobertaForCausalLMModule(nn.Module):
+ config: RobertaConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+ self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ token_type_ids: Optional[jnp.ndarray] = None,
+ head_mask: Optional[jnp.ndarray] = None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.roberta(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+ else:
+ shared_embedding = None
+
+ # Compute the prediction scores
+ logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
+
+ if not return_dict:
+ return (logits,) + outputs[1:]
+
+ return FlaxCausalLMOutputWithCrossAttentions(
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
+ autoregressive tasks.
+ """,
+ ROBERTA_START_DOCSTRING,
+)
+class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
+ module_class = FlaxRobertaForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ FlaxRobertaForCausalLM,
+ _TOKENIZER_FOR_DOC,
+ _CHECKPOINT_FOR_DOC,
+ FlaxCausalLMOutputWithCrossAttentions,
+ _CONFIG_FOR_DOC,
+)
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 2b7d2f87859350..0b57b1031e537b 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -182,7 +182,7 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@@ -426,7 +426,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -817,7 +818,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -919,7 +920,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1080,7 +1081,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, MaskedLMOutput]:
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1193,7 +1194,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutput]:
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1290,7 +1291,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, MultipleChoiceModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1390,7 +1391,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, TokenClassifierOutput]:
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1496,7 +1497,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py
index 7c39b7334a46c9..a320664bcea57f 100644
--- a/src/transformers/models/roberta/modeling_tf_roberta.py
+++ b/src/transformers/models/roberta/modeling_tf_roberta.py
@@ -463,8 +463,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
diff --git a/src/transformers/models/roberta/tokenization_roberta.py b/src/transformers/models/roberta/tokenization_roberta.py
index 0d87615c15693f..10b28125e92bce 100644
--- a/src/transformers/models/roberta/tokenization_roberta.py
+++ b/src/transformers/models/roberta/tokenization_roberta.py
@@ -39,7 +39,9 @@
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/vocab.json",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/vocab.json",
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/vocab.json",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json",
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/merges.txt",
@@ -47,7 +49,9 @@
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/merges.txt",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/merges.txt",
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/merges.txt",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt",
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt"
+ ),
},
}
@@ -320,7 +324,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
diff --git a/src/transformers/models/roberta/tokenization_roberta_fast.py b/src/transformers/models/roberta/tokenization_roberta_fast.py
index 7b774f69f19a29..cb055430b13675 100644
--- a/src/transformers/models/roberta/tokenization_roberta_fast.py
+++ b/src/transformers/models/roberta/tokenization_roberta_fast.py
@@ -35,7 +35,9 @@
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/vocab.json",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/vocab.json",
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/vocab.json",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json",
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json"
+ ),
},
"merges_file": {
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/merges.txt",
@@ -43,15 +45,21 @@
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/merges.txt",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/merges.txt",
"roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/merges.txt",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt",
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt"
+ ),
},
"tokenizer_file": {
"roberta-base": "https://huggingface.co/roberta-base/resolve/main/tokenizer.json",
"roberta-large": "https://huggingface.co/roberta-large/resolve/main/tokenizer.json",
"roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/tokenizer.json",
"distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/tokenizer.json",
- "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/tokenizer.json",
- "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/tokenizer.json",
+ "roberta-base-openai-detector": (
+ "https://huggingface.co/roberta-base-openai-detector/resolve/main/tokenizer.json"
+ ),
+ "roberta-large-openai-detector": (
+ "https://huggingface.co/roberta-large-openai-detector/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/roformer/__init__.py b/src/transformers/models/roformer/__init__.py
index ec99c5a3b86ad3..909259ead6017f 100644
--- a/src/transformers/models/roformer/__init__.py
+++ b/src/transformers/models/roformer/__init__.py
@@ -17,7 +17,14 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {
@@ -25,10 +32,20 @@
"tokenization_roformer": ["RoFormerTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_roformer_fast"] = ["RoFormerTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_roformer"] = [
"ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"RoFormerForCausalLM",
@@ -44,7 +61,12 @@
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_roformer"] = [
"TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRoFormerForCausalLM",
@@ -59,7 +81,12 @@
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_roformer"] = [
"FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"FlaxRoFormerForMaskedLM",
@@ -76,10 +103,20 @@
from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerOnnxConfig
from .tokenization_roformer import RoFormerTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_roformer_fast import RoFormerTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_roformer import (
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
RoFormerForCausalLM,
@@ -94,7 +131,12 @@
load_tf_weights_in_roformer,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_roformer import (
TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRoFormerForCausalLM,
@@ -108,7 +150,12 @@
TFRoFormerPreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_roformer import (
FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
FlaxRoFormerForMaskedLM,
diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py
index 2c5de2bbbe262d..ea547ca52d1bd9 100644
--- a/src/transformers/models/roformer/configuration_roformer.py
+++ b/src/transformers/models/roformer/configuration_roformer.py
@@ -27,10 +27,18 @@
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
- "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
- "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
- "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
- "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
+ "junnyu/roformer_chinese_char_small": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json"
+ ),
+ "junnyu/roformer_chinese_char_base": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json"
+ ),
+ "junnyu/roformer_small_discriminator": (
+ "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json"
+ ),
+ "junnyu/roformer_small_generator": (
+ "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json"
+ ),
# See all RoFormer models at https://huggingface.co/models?filter=roformer
}
diff --git a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py
index 33edf59f6bfd74..0ab8b671d0752e 100755
--- a/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py
@@ -51,8 +51,10 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained BERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained BERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/roformer/modeling_flax_roformer.py b/src/transformers/models/roformer/modeling_flax_roformer.py
index 37dd72966646f7..011f1610488da2 100644
--- a/src/transformers/models/roformer/modeling_flax_roformer.py
+++ b/src/transformers/models/roformer/modeling_flax_roformer.py
@@ -180,8 +180,8 @@ class FlaxRoFormerSelfAttention(nn.Module):
def setup(self) -> None:
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
- "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
- : {self.config.num_attention_heads}"
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
+ " : {self.config.num_attention_heads}"
)
self.query = nn.Dense(
@@ -456,8 +456,8 @@ def __call__(
if head_mask is not None:
if head_mask.shape[0] != (len(self.layers)):
raise ValueError(
- f"The head_mask should be specified for {len(self.layers)} layers, but it is for \
- {head_mask.shape[0]}."
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
+ f" {head_mask.shape[0]}."
)
for i, layer in enumerate(self.layers):
diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py
index fe746971504697..353b1b39217aec 100644
--- a/src/transformers/models/roformer/modeling_roformer.py
+++ b/src/transformers/models/roformer/modeling_roformer.py
@@ -699,8 +699,8 @@ class RoFormerPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = []
_keys_to_ignore_on_load_unexpected = [
- r"roformer\.embeddings_project\.weight",
- r"roformer\.embeddings_project\.bias",
+ r"roformer.embeddings_project.weight",
+ r"roformer.embeddings_project.bias",
]
def _init_weights(self, module):
@@ -900,7 +900,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
diff --git a/src/transformers/models/roformer/tokenization_roformer.py b/src/transformers/models/roformer/tokenization_roformer.py
index e5e3728c03fc3a..ac1efc72d08978 100644
--- a/src/transformers/models/roformer/tokenization_roformer.py
+++ b/src/transformers/models/roformer/tokenization_roformer.py
@@ -31,10 +31,18 @@
"vocab_file": {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
- "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
- "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
- "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
- "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
+ "junnyu/roformer_chinese_char_small": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_chinese_char_base": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_small_discriminator": (
+ "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_small_generator": (
+ "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt"
+ ),
}
}
@@ -144,8 +152,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
diff --git a/src/transformers/models/roformer/tokenization_roformer_fast.py b/src/transformers/models/roformer/tokenization_roformer_fast.py
index 26c37d4580f020..7b2cab56886200 100644
--- a/src/transformers/models/roformer/tokenization_roformer_fast.py
+++ b/src/transformers/models/roformer/tokenization_roformer_fast.py
@@ -27,16 +27,24 @@
logger = logging.get_logger(__name__)
-VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
- "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
- "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
- "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
- "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
+ "junnyu/roformer_chinese_char_small": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_chinese_char_base": (
+ "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_small_discriminator": (
+ "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt"
+ ),
+ "junnyu/roformer_small_generator": (
+ "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/segformer/__init__.py b/src/transformers/models/segformer/__init__.py
index fed4e8127cbd21..1ce4ecb07a9cfc 100644
--- a/src/transformers/models/segformer/__init__.py
+++ b/src/transformers/models/segformer/__init__.py
@@ -17,17 +17,25 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"],
-}
+_import_structure = {"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_segformer"] = ["SegformerFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_segformer"] = [
"SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"SegformerDecodeHead",
@@ -42,10 +50,20 @@
if TYPE_CHECKING:
from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_segformer import SegformerFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_segformer import (
SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
SegformerDecodeHead,
diff --git a/src/transformers/models/segformer/configuration_segformer.py b/src/transformers/models/segformer/configuration_segformer.py
index fa54c62c227c1f..faec5d6c4c9fb8 100644
--- a/src/transformers/models/segformer/configuration_segformer.py
+++ b/src/transformers/models/segformer/configuration_segformer.py
@@ -23,7 +23,9 @@
logger = logging.get_logger(__name__)
SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "nvidia/segformer-b0-finetuned-ade-512-512": "https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512/resolve/main/config.json",
+ "nvidia/segformer-b0-finetuned-ade-512-512": (
+ "https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512/resolve/main/config.json"
+ ),
# See all SegFormer models at https://huggingface.co/models?filter=segformer
}
@@ -122,8 +124,8 @@ def __init__(
if "reshape_last_stage" in kwargs and kwargs["reshape_last_stage"] is False:
warnings.warn(
- "Reshape_last_stage is set to False in this config. This argument is deprecated and will soon be removed, "
- "as the behaviour will default to that of reshape_last_stage = True.",
+ "Reshape_last_stage is set to False in this config. This argument is deprecated and will soon be"
+ " removed, as the behaviour will default to that of reshape_last_stage = True.",
FutureWarning,
)
diff --git a/src/transformers/models/segformer/feature_extraction_segformer.py b/src/transformers/models/segformer/feature_extraction_segformer.py
index c706c559af3c1e..0a9ae01ef121e5 100644
--- a/src/transformers/models/segformer/feature_extraction_segformer.py
+++ b/src/transformers/models/segformer/feature_extraction_segformer.py
@@ -158,8 +158,9 @@ def __call__(
if not valid_segmentation_maps:
raise ValueError(
- "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
+ " example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
+ " examples)."
)
is_batched = bool(
diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py
index d8989e340c3ca9..55ac976b354422 100755
--- a/src/transformers/models/segformer/modeling_segformer.py
+++ b/src/transformers/models/segformer/modeling_segformer.py
@@ -112,7 +112,7 @@ def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
diff --git a/src/transformers/models/sew/__init__.py b/src/transformers/models/sew/__init__.py
index 4ee9380137d102..bfe39bea1bdcf5 100644
--- a/src/transformers/models/sew/__init__.py
+++ b/src/transformers/models/sew/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"],
-}
+_import_structure = {"configuration_sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_sew"] = [
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
"SEWForCTC",
@@ -36,7 +39,12 @@
if TYPE_CHECKING:
from .configuration_sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_sew import (
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
SEWForCTC,
diff --git a/src/transformers/models/sew/configuration_sew.py b/src/transformers/models/sew/configuration_sew.py
index ad6a6afa699212..e9665baeede149 100644
--- a/src/transformers/models/sew/configuration_sew.py
+++ b/src/transformers/models/sew/configuration_sew.py
@@ -76,13 +76,13 @@ class SEWConfig(PretrainedConfig):
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
*conv_dim*.
diff --git a/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
index 6449288810f4fc..58c0338a850d0f 100644
--- a/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
@@ -67,9 +67,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -137,28 +138,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index ac2a6293cb95a3..1ead29326139f0 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -489,7 +489,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -505,7 +506,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -526,7 +528,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
diff --git a/src/transformers/models/sew_d/__init__.py b/src/transformers/models/sew_d/__init__.py
index bc577400405733..905bfb0f5b6834 100644
--- a/src/transformers/models/sew_d/__init__.py
+++ b/src/transformers/models/sew_d/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"],
-}
+_import_structure = {"configuration_sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_sew_d"] = [
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
"SEWDForCTC",
@@ -36,7 +39,12 @@
if TYPE_CHECKING:
from .configuration_sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_sew_d import (
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST,
SEWDForCTC,
diff --git a/src/transformers/models/sew_d/configuration_sew_d.py b/src/transformers/models/sew_d/configuration_sew_d.py
index 996338cb0f0537..b078623cfda692 100644
--- a/src/transformers/models/sew_d/configuration_sew_d.py
+++ b/src/transformers/models/sew_d/configuration_sew_d.py
@@ -94,13 +94,13 @@ class SEWDConfig(PretrainedConfig):
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 2, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
*conv_dim*.
diff --git a/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
index e6529eea04dda3..942add470b9c68 100644
--- a/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
@@ -69,9 +69,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -141,28 +142,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index a297e4c7b25b31..defdd715846a61 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -261,7 +261,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None
if dropout > 0 and mask is None:
- mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
if isinstance(local_context, DropoutContext):
if local_context.mask is None:
@@ -532,9 +532,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
- rmask = ~(mask.bool())
+ rmask = ~(mask.to(torch.bool))
- output = input.masked_fill(rmask, float("-inf"))
+ output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
@@ -557,7 +557,7 @@ def symbolic(g, self, mask, dim):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
- output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
+ output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
@@ -711,7 +711,7 @@ def __init__(self, config):
def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
def forward(
@@ -792,7 +792,7 @@ def forward(
.contiguous()
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
diff --git a/src/transformers/models/speech_encoder_decoder/__init__.py b/src/transformers/models/speech_encoder_decoder/__init__.py
index a040990864a998..4eea93eacddc45 100644
--- a/src/transformers/models/speech_encoder_decoder/__init__.py
+++ b/src/transformers/models/speech_encoder_decoder/__init__.py
@@ -18,26 +18,44 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
-_import_structure = {
- "configuration_speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
-}
+_import_structure = {"configuration_speech_encoder_decoder": ["SpeechEncoderDecoderConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_speech_encoder_decoder"] = ["SpeechEncoderDecoderModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_speech_encoder_decoder"] = ["FlaxSpeechEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_speech_encoder_decoder import SpeechEncoderDecoderModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
else:
diff --git a/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
index ca3e4966aaf96b..8b648f8e21bc3c 100644
--- a/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
@@ -77,7 +77,8 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
if "encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError(
- f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
+ f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and"
+ f" `decoder` sub-configurations are passed, but only {kwargs}"
)
encoder_config = kwargs.pop("encoder")
diff --git a/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py b/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
index 3c25ab706f4e19..8680f96e50d561 100644
--- a/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
+++ b/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
@@ -75,9 +75,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -147,28 +148,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py b/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
index 40433bba1344be..0a4bc48dea3246 100644
--- a/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
+++ b/src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
@@ -77,9 +77,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -153,28 +154,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
index faabeae17fa490..cd304fa0c0a890 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
@@ -357,10 +357,10 @@ def __init__(
# Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# make sure input & output embeddings are not tied
@@ -389,7 +389,8 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
if not decoder_batch_size == batch_size:
raise ValueError(
- f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder"
+ f" and {decoder_batch_size} for decoder."
)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
@@ -627,7 +628,7 @@ def _decoder_forward(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
- encoder_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
**kwargs,
)
@@ -713,7 +714,8 @@ def __call__(
# prepare decoder inputs
if decoder_input_ids is None:
raise ValueError(
- "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must"
+ " be specified as an input argument."
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
@@ -895,10 +897,9 @@ def from_encoder_decoder_pretrained(
)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
index 1dbba59f9ef326..8b717641bb827b 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -199,10 +199,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# initialize with config
@@ -221,11 +221,13 @@ def __init__(
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
@@ -410,10 +412,9 @@ def from_encoder_decoder_pretrained(
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
@@ -599,8 +600,8 @@ def prepare_inputs_for_generation(
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
- "Resizing the embedding layers via the SpeechEncoderDecoderModel directly is not supported. "
- "Please use the respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
+ "Resizing the embedding layers via the SpeechEncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
diff --git a/src/transformers/models/speech_to_text/__init__.py b/src/transformers/models/speech_to_text/__init__.py
index 0cccf667213633..20eba2bf6a2d48 100644
--- a/src/transformers/models/speech_to_text/__init__.py
+++ b/src/transformers/models/speech_to_text/__init__.py
@@ -17,26 +17,45 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_speech_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_speech_available,
+ is_tf_available,
+ is_torch_available,
+)
_import_structure = {
- "configuration_speech_to_text": [
- "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "Speech2TextConfig",
- ],
+ "configuration_speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig"],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
-if is_speech_available():
+try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]
if is_sentencepiece_available():
_import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_speech_to_text"] = [
"TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSpeech2TextForConditionalGeneration",
@@ -44,7 +63,12 @@
"TFSpeech2TextPreTrainedModel",
]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_speech_to_text"] = [
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
@@ -56,16 +80,31 @@
if TYPE_CHECKING:
from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_speech_to_text import Speech2TextTokenizer
- if is_speech_available():
+ try:
+ if not is_speech_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
if is_sentencepiece_available():
from .processing_speech_to_text import Speech2TextProcessor
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_speech_to_text import (
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSpeech2TextForConditionalGeneration,
@@ -73,7 +112,12 @@
TFSpeech2TextPreTrainedModel,
)
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration,
diff --git a/src/transformers/models/speech_to_text/configuration_speech_to_text.py b/src/transformers/models/speech_to_text/configuration_speech_to_text.py
index f08bbf51e1b2b6..f12be50b538cef 100644
--- a/src/transformers/models/speech_to_text/configuration_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/configuration_speech_to_text.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/config.json",
+ "facebook/s2t-small-librispeech-asr": (
+ "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/config.json"
+ ),
# See all Speech2Text models at https://huggingface.co/models?filter=speech_to_text
}
diff --git a/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py
index df8bc485364f3f..6c1cd993fe46c4 100644
--- a/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py
+++ b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py
@@ -102,7 +102,8 @@ def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_
]
):
raise ValueError(
- f"Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing, but all the following weights are missing {missing}"
+ "Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing,"
+ f" but all the following weights are missing {missing}"
)
if tie_embeds:
diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
index e6ff52f183607c..4294c48c71f0ee 100644
--- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -190,8 +190,9 @@ def __call__(
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of {self.sampling_rate}. "
- f"Please make sure that the provided `raw_speech` input was sampled with {self.sampling_rate} and not {sampling_rate}."
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
index 7c2e1835370a62..78fac2cac1ecc3 100755
--- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -69,7 +69,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -91,7 +91,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class Conv1dSubsampler(nn.Module):
@@ -292,7 +292,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -308,7 +309,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -329,7 +331,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -625,9 +628,9 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
- If you want to change padding behavior, you should read [`modeling_speech_to_text._prepare_decoder_inputs`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
+ If you want to change padding behavior, you should read
+ [`modeling_speech_to_text._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@@ -885,7 +888,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1024,9 +1027,10 @@ def forward(
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
- assert attn_mask.size()[0] == (
- len(self.layers)
- ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
@@ -1041,7 +1045,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
+ " False`..."
)
use_cache = False
@@ -1247,8 +1252,8 @@ def forward(
class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
- r"encoder\.version",
- r"decoder\.version",
+ r"encoder.version",
+ r"decoder.version",
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
]
diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
index 8980636c3b32b5..f61ddd7fed0c90 100755
--- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
@@ -331,7 +331,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -341,7 +344,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -357,7 +363,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -374,7 +383,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
@@ -856,7 +868,10 @@ def call(
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
- message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.",
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
)
for idx, encoder_layer in enumerate(self.layers):
@@ -1065,7 +1080,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
- message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
)
for idx, decoder_layer in enumerate(self.layers):
@@ -1387,7 +1405,7 @@ def call(
>>> input_features = processor(
... ds["speech"][0], sampling_rate=16000, return_tensors="tf"
- >>> ).input_features # Batch size 1
+ ... ).input_features # Batch size 1
>>> generated_ids = model.generate(input_features)
>>> transcription = processor.batch_decode(generated_ids)
diff --git a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
index 7d77c945ced8c2..e1bc681499f7cb 100644
--- a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py
@@ -36,10 +36,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/vocab.json",
+ "facebook/s2t-small-librispeech-asr": (
+ "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/vocab.json"
+ ),
},
"spm_file": {
- "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/sentencepiece.bpe.model"
+ "facebook/s2t-small-librispeech-asr": (
+ "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/sentencepiece.bpe.model"
+ )
},
}
diff --git a/src/transformers/models/speech_to_text_2/__init__.py b/src/transformers/models/speech_to_text_2/__init__.py
index d4ea8d037a0d04..645a397460937b 100644
--- a/src/transformers/models/speech_to_text_2/__init__.py
+++ b/src/transformers/models/speech_to_text_2/__init__.py
@@ -17,20 +17,28 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_speech_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_speech_available,
+ is_torch_available,
+)
_import_structure = {
- "configuration_speech_to_text_2": [
- "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "Speech2Text2Config",
- ],
+ "configuration_speech_to_text_2": ["SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2Text2Config"],
"processing_speech_to_text_2": ["Speech2Text2Processor"],
"tokenization_speech_to_text_2": ["Speech2Text2Tokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_speech_to_text_2"] = [
"SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2Text2ForCausalLM",
@@ -43,7 +51,12 @@
from .processing_speech_to_text_2 import Speech2Text2Processor
from .tokenization_speech_to_text_2 import Speech2Text2Tokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_speech_to_text_2 import (
SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2Text2ForCausalLM,
diff --git a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py
index d27bad73c73ca2..c1b3cf7e4c7fb7 100644
--- a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/config.json",
+ "facebook/s2t-wav2vec2-large-en-de": (
+ "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/config.json"
+ ),
# See all Speech2Text models at https://huggingface.co/models?filter=speech2text2
}
diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
index dccbd2adf48be3..d90c4c87b6bd91 100755
--- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
@@ -49,7 +49,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -71,7 +71,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
@@ -238,7 +238,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -254,7 +255,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -275,7 +277,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -492,7 +495,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -633,7 +636,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -649,7 +653,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
+ " False`..."
)
use_cache = False
@@ -735,7 +740,8 @@ def forward(self, *args, **kwargs):
@add_start_docstrings(
- "The Speech2Text2 Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and [`SpeechEncoderDecoder`].",
+ "The Speech2Text2 Decoder with a language modeling head. Can be used as the decoder part of"
+ " [`EncoderDecoderModel`] and [`SpeechEncoderDecoder`].",
SPEECH_TO_TEXT_2_START_DOCSTRING,
)
class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
diff --git a/src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py
index 51d5c31ec9912d..3365dfe382ae6f 100644
--- a/src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py
@@ -33,13 +33,19 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/vocab.json",
+ "facebook/s2t-wav2vec2-large-en-de": (
+ "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/vocab.json"
+ ),
},
"tokenizer_config_file": {
- "facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/tokenizer_config.json",
+ "facebook/s2t-wav2vec2-large-en-de": (
+ "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/tokenizer_config.json"
+ ),
},
"merges_file": {
- "facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/merges.txt",
+ "facebook/s2t-wav2vec2-large-en-de": (
+ "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/merges.txt"
+ ),
},
}
@@ -244,7 +250,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
if self.bpe_ranks is None:
diff --git a/src/transformers/models/splinter/__init__.py b/src/transformers/models/splinter/__init__.py
index 6a2308bbf53574..9f056d7200a197 100644
--- a/src/transformers/models/splinter/__init__.py
+++ b/src/transformers/models/splinter/__init__.py
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
@@ -25,13 +25,24 @@
"tokenization_splinter": ["SplinterTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_splinter_fast"] = ["SplinterTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_splinter"] = [
"SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
"SplinterForQuestionAnswering",
+ "SplinterForPreTraining",
"SplinterLayer",
"SplinterModel",
"SplinterPreTrainedModel",
@@ -42,12 +53,23 @@
from .configuration_splinter import SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP, SplinterConfig
from .tokenization_splinter import SplinterTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_splinter_fast import SplinterTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_splinter import (
SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ SplinterForPreTraining,
SplinterForQuestionAnswering,
SplinterLayer,
SplinterModel,
diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py
index 4d695b3137e1f0..ae8ba4fa34b0c7 100755
--- a/src/transformers/models/splinter/modeling_splinter.py
+++ b/src/transformers/models/splinter/modeling_splinter.py
@@ -16,6 +16,7 @@
import math
+from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
@@ -24,7 +25,7 @@
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
-from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -126,7 +127,7 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@@ -370,7 +371,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -710,7 +712,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -939,3 +941,171 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
+
+
+@dataclass
+class SplinterForPreTrainingOutput(ModelOutput):
+ """
+ Class for outputs of Splinter as a span selection model.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
+ Span-end scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ start_logits: torch.FloatTensor = None
+ end_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@add_start_docstrings(
+ """
+ Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task
+ is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans
+ instead.
+ """,
+ SPLINTER_START_DOCSTRING,
+)
+class SplinterForPreTraining(SplinterPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.splinter = SplinterModel(config)
+ self.splinter_qass = QuestionAwareSpanSelectionHead(config)
+ self.question_token_id = config.question_token_id
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(
+ SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length")
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ question_positions: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, SplinterForPreTrainingOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
+ The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
+ num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
+ the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
+ sequence_length)`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if question_positions is None and start_positions is not None and end_positions is not None:
+ raise TypeError("question_positions must be specified in order to calculate the loss")
+
+ elif question_positions is None and input_ids is None:
+ raise TypeError("question_positions must be specified when input_embeds is used")
+
+ elif question_positions is None:
+ question_positions = self._prepare_question_positions(input_ids)
+
+ outputs = self.splinter(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ batch_size, sequence_length, dim = sequence_output.size()
+ # [batch_size, num_questions, sequence_length]
+ start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)
+
+ num_questions = question_positions.size(1)
+ if attention_mask is not None:
+ attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
+ batch_size, num_questions, sequence_length
+ )
+ start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0
+ end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0
+
+ total_loss = None
+ # [batch_size, num_questions, sequence_length]
+ if start_positions is not None and end_positions is not None:
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ start_positions.clamp_(0, max(0, sequence_length - 1))
+ end_positions.clamp_(0, max(0, sequence_length - 1))
+
+ # Ignore zero positions in the loss. Splinter never predicts zero
+ # during pretraining and zero is used for padding question
+ # tokens as well as for start and end positions of padded
+ # question tokens.
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+ start_loss = loss_fct(
+ start_logits.view(batch_size * num_questions, sequence_length),
+ start_positions.view(batch_size * num_questions),
+ )
+ end_loss = loss_fct(
+ end_logits.view(batch_size * num_questions, sequence_length),
+ end_positions.view(batch_size * num_questions),
+ )
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return SplinterForPreTrainingOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor:
+ rows, flat_positions = torch.where(input_ids == self.config.question_token_id)
+ num_questions = torch.bincount(rows)
+ positions = torch.full(
+ (input_ids.size(0), num_questions.max()),
+ self.config.pad_token_id,
+ dtype=torch.long,
+ device=input_ids.device,
+ )
+ cols = torch.cat([torch.arange(n) for n in num_questions])
+ positions[rows, cols] = flat_positions
+ return positions
diff --git a/src/transformers/models/splinter/tokenization_splinter.py b/src/transformers/models/splinter/tokenization_splinter.py
index 9649da03f9f18b..f600566e6e9411 100644
--- a/src/transformers/models/splinter/tokenization_splinter.py
+++ b/src/transformers/models/splinter/tokenization_splinter.py
@@ -153,8 +153,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
diff --git a/src/transformers/models/squeezebert/__init__.py b/src/transformers/models/squeezebert/__init__.py
index 433b9f93343f25..9f758bebe0247c 100644
--- a/src/transformers/models/squeezebert/__init__.py
+++ b/src/transformers/models/squeezebert/__init__.py
@@ -18,18 +18,32 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
- "configuration_squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig"],
+ "configuration_squeezebert": [
+ "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "SqueezeBertConfig",
+ "SqueezeBertOnnxConfig",
+ ],
"tokenization_squeezebert": ["SqueezeBertTokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_squeezebert_fast"] = ["SqueezeBertTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_squeezebert"] = [
"SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SqueezeBertForMaskedLM",
@@ -44,13 +58,27 @@
if TYPE_CHECKING:
- from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
+ from .configuration_squeezebert import (
+ SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ SqueezeBertConfig,
+ SqueezeBertOnnxConfig,
+ )
from .tokenization_squeezebert import SqueezeBertTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_squeezebert_fast import SqueezeBertTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
SqueezeBertForMaskedLM,
diff --git a/src/transformers/models/squeezebert/configuration_squeezebert.py b/src/transformers/models/squeezebert/configuration_squeezebert.py
index 5a77495fc7044c..41b47ff5750ec3 100644
--- a/src/transformers/models/squeezebert/configuration_squeezebert.py
+++ b/src/transformers/models/squeezebert/configuration_squeezebert.py
@@ -13,17 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" SqueezeBERT model configuration"""
+from collections import OrderedDict
+from typing import Mapping
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)
SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/config.json",
+ "squeezebert/squeezebert-uncased": (
+ "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/config.json"
+ ),
"squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/config.json",
- "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/config.json",
+ "squeezebert/squeezebert-mnli-headless": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/config.json"
+ ),
}
@@ -150,3 +157,20 @@ def __init__(
self.post_attention_groups = post_attention_groups
self.intermediate_groups = intermediate_groups
self.output_groups = output_groups
+
+
+# # Copied from transformers.models.bert.configuration_bert.BertOnxxConfig with Bert->SqueezeBert
+class SqueezeBertOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("token_type_ids", dynamic_axis),
+ ]
+ )
diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py
index b8cdfe16a9f4bd..210531772984a2 100644
--- a/src/transformers/models/squeezebert/modeling_squeezebert.py
+++ b/src/transformers/models/squeezebert/modeling_squeezebert.py
@@ -612,7 +612,7 @@ def forward(
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
diff --git a/src/transformers/models/squeezebert/tokenization_squeezebert.py b/src/transformers/models/squeezebert/tokenization_squeezebert.py
index e41e576455fe6a..72d927eccafb59 100644
--- a/src/transformers/models/squeezebert/tokenization_squeezebert.py
+++ b/src/transformers/models/squeezebert/tokenization_squeezebert.py
@@ -24,9 +24,13 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt",
+ "squeezebert/squeezebert-uncased": (
+ "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt"
+ ),
"squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/vocab.txt",
- "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt",
+ "squeezebert/squeezebert-mnli-headless": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt"
+ ),
}
}
diff --git a/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py b/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py
index 58708030f9f39e..5ee656e5a8d5e4 100644
--- a/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py
+++ b/src/transformers/models/squeezebert/tokenization_squeezebert_fast.py
@@ -25,14 +25,24 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt",
+ "squeezebert/squeezebert-uncased": (
+ "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/vocab.txt"
+ ),
"squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/vocab.txt",
- "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt",
+ "squeezebert/squeezebert-mnli-headless": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/vocab.txt"
+ ),
},
"tokenizer_file": {
- "squeezebert/squeezebert-uncased": "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/tokenizer.json",
- "squeezebert/squeezebert-mnli": "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/tokenizer.json",
- "squeezebert/squeezebert-mnli-headless": "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/tokenizer.json",
+ "squeezebert/squeezebert-uncased": (
+ "https://huggingface.co/squeezebert/squeezebert-uncased/resolve/main/tokenizer.json"
+ ),
+ "squeezebert/squeezebert-mnli": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli/resolve/main/tokenizer.json"
+ ),
+ "squeezebert/squeezebert-mnli-headless": (
+ "https://huggingface.co/squeezebert/squeezebert-mnli-headless/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/swin/__init__.py b/src/transformers/models/swin/__init__.py
index b8cb65d08b3ab4..33a9bddeea7332 100644
--- a/src/transformers/models/swin/__init__.py
+++ b/src/transformers/models/swin/__init__.py
@@ -18,15 +18,18 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
-_import_structure = {
- "configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"],
-}
+_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_swin"] = [
"SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
"SwinForImageClassification",
@@ -35,11 +38,29 @@
"SwinPreTrainedModel",
]
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_swin"] = [
+ "TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TFSwinForImageClassification",
+ "TFSwinForMaskedImageModeling",
+ "TFSwinModel",
+ "TFSwinPreTrainedModel",
+ ]
if TYPE_CHECKING:
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_swin import (
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
SwinForImageClassification,
@@ -48,6 +69,19 @@
SwinPreTrainedModel,
)
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_swin import (
+ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
+ TFSwinModel,
+ TFSwinPreTrainedModel,
+ )
else:
import sys
diff --git a/src/transformers/models/swin/configuration_swin.py b/src/transformers/models/swin/configuration_swin.py
index 9956482b9ab778..878a73e9208b5e 100644
--- a/src/transformers/models/swin/configuration_swin.py
+++ b/src/transformers/models/swin/configuration_swin.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/swin-tiny-patch4-window7-224": "https://huggingface.co/microsoft/swin-tiny-patch4-window7-224/resolve/main/config.json",
+ "microsoft/swin-tiny-patch4-window7-224": (
+ "https://huggingface.co/microsoft/swin-tiny-patch4-window7-224/resolve/main/config.json"
+ ),
# See all Swin models at https://huggingface.co/models?filter=swin
}
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 51a19ab73b8ccb..be46b8dc2f8fce 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -18,7 +18,7 @@
import collections.abc
import math
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
- batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+ batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
return windows
@@ -272,7 +272,9 @@ def __init__(self, config, use_mask_token=False):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, pixel_values, bool_masked_pos=None):
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+ ) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()
@@ -317,7 +319,7 @@ def maybe_pad(self, pixel_values, height, width):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
- def forward(self, pixel_values):
+ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
_, _, height, width = pixel_values.shape
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
@@ -342,7 +344,7 @@ class SwinPatchMerging(nn.Module):
Normalization layer class.
"""
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
@@ -357,7 +359,7 @@ def maybe_pad(self, input_feature, height, width):
return input_feature
- def forward(self, input_feature, input_dimensions):
+ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
height, width = input_dimensions
# `dim` is height * width
batch_size, dim, num_channels = input_feature.shape
@@ -400,7 +402,7 @@ def __init__(self, config, dim, num_heads):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
- f"The hidden size ({dim}) is not a multiple of the number of attention " f"heads ({num_heads})"
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
@@ -433,16 +435,16 @@ def __init__(self, config, dim, num_heads):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- output_attentions=False,
- ):
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
@@ -486,7 +488,7 @@ def forward(
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -499,7 +501,7 @@ def __init__(self, config, dim):
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- def forward(self, hidden_states, input_tensor):
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
@@ -531,7 +533,13 @@ def prune_heads(self, heads):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
- def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@@ -547,7 +555,7 @@ def __init__(self, config, dim):
else:
self.intermediate_act_fn = config.hidden_act
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
@@ -559,7 +567,7 @@ def __init__(self, config, dim):
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
@@ -621,7 +629,13 @@ def maybe_pad(self, hidden_states, height, width):
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
- def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
@@ -703,7 +717,13 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d
self.pointing = False
- def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
@@ -752,13 +772,13 @@ def __init__(self, config, grid_size):
def forward(
self,
- hidden_states,
- input_dimensions,
- head_mask=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- ):
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, SwinEncoderOutput]:
all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
@@ -920,13 +940,13 @@ class PreTrainedModel
)
def forward(
self,
- pixel_values=None,
- bool_masked_pos=None,
- head_mask=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ pixel_values: Optional[torch.FloatTensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SwinModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -977,7 +997,8 @@ def forward(
@add_start_docstrings(
- "Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM `__.",
+ "Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
+ " `__.",
SWIN_START_DOCSTRING,
)
class SwinForMaskedImageModeling(SwinPreTrainedModel):
@@ -999,13 +1020,13 @@ def __init__(self, config):
@replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
- pixel_values=None,
- bool_masked_pos=None,
- head_mask=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ pixel_values: Optional[torch.FloatTensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SwinMaskedImageModelingOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
@@ -1047,11 +1068,10 @@ def forward(
)
sequence_output = outputs[0]
-
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output.transpose(1, 2)
batch_size, num_channels, sequence_length = sequence_output.shape
- height = width = int(sequence_length**0.5)
+ height = width = math.floor(sequence_length**0.5)
sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
# Reconstruct pixel values
@@ -1115,13 +1135,13 @@ def __init__(self, config):
)
def forward(
self,
- pixel_values=None,
- head_mask=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ pixel_values: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SwinImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py
new file mode 100644
index 00000000000000..5b9ecbeccfafb2
--- /dev/null
+++ b/src/transformers/models/swin/modeling_tf_swin.py
@@ -0,0 +1,1464 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF 2.0 Swin Transformer model."""
+
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ...activations_tf import ACT2FN
+from ...modeling_tf_utils import (
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_swin import SwinConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "SwinConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224"
+_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/swin-tiny-patch4-window7-224",
+ # See all Swin models at https://huggingface.co/models?filter=swin
+]
+
+# to_2tuple, drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow
+# implementations of PyTorch functionalities in the timm library.
+
+
+@dataclass
+class TFSwinEncoderOutput(ModelOutput):
+ """
+ Swin encoder's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: tf.Tensor = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+ reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+@dataclass
+class TFSwinModelOutput(ModelOutput):
+ """
+ Swin model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+ Average pooling of the last layer hidden-state.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: tf.Tensor = None
+ pooler_output: tf.Tensor = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+ reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+@dataclass
+class TFSwinMaskedImageModelingOutput(ModelOutput):
+ """
+ Swin masked image model outputs.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+ Masked image modeling (MLM) loss.
+ logits (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Reconstructed pixel values.
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[tf.Tensor] = None
+ logits: tf.Tensor = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+ reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+@dataclass
+class TFSwinImageClassifierOutput(ModelOutput):
+ """
+ Swin outputs for image classification.
+
+ Args:
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
+ `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ loss: Optional[tf.Tensor] = None
+ logits: tf.Tensor = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+ reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.to_2tuple
+def to_2tuple(x) -> Tuple[Any, Any]:
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return (x, x)
+
+
+def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
+ """
+ Partitions the given input into windows.
+ """
+ batch_size, height, width, num_channels = shape_list(input_feature)
+ input_feature = tf.reshape(
+ input_feature,
+ (batch_size, height // window_size, window_size, width // window_size, window_size, num_channels),
+ )
+ windows = tf.transpose(input_feature, (0, 1, 3, 2, 4, 5))
+ windows = tf.reshape(windows, (-1, window_size, window_size, num_channels))
+ return windows
+
+
+def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int) -> tf.Tensor:
+ """
+ Merges windows to produce higher resolution features.
+ """
+ x = shape_list(windows)[0]
+ y = tf.cast(height * width / window_size / window_size, tf.int32)
+ batch_size = int(x / y)
+ windows = tf.reshape(
+ windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
+ )
+ windows = tf.transpose(windows, (0, 1, 3, 2, 4, 5))
+ windows = tf.reshape(windows, (batch_size, height, width, -1))
+ return windows
+
+
+def drop_path(
+ input: tf.Tensor, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+) -> tf.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ input_shape = shape_list(input)
+ ndim = len(input_shape)
+ shape = [input_shape[0]] + [1] * (ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = tf.random.uniform(shape)
+ random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor /= keep_prob
+ return input * random_tensor
+
+
+class TFSwinEmbeddings(tf.keras.layers.Layer):
+ """
+ Construct the patch and position embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.patch_embeddings = TFSwinPatchEmbeddings(
+ image_size=config.image_size,
+ patch_size=config.patch_size,
+ num_channels=config.num_channels,
+ embed_dim=config.embed_dim,
+ name="patch_embeddings",
+ )
+ self.num_patches = self.patch_embeddings.num_patches
+ self.patch_grid = self.patch_embeddings.grid_size
+ self.embed_dim = config.embed_dim
+ self.use_mask_token = use_mask_token
+ self.use_absolute_embeddings = config.use_absolute_embeddings
+
+ self.norm = tf.keras.layers.LayerNormalization(name="norm", epsilon=1e-5)
+ self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
+
+ def build(self, input_shape: tf.TensorShape) -> None:
+ if self.use_mask_token:
+ self.mask_token = self.add_weight(shape=(1, 1, self.embed_dim), initializer="zeros", name="mask_token")
+ else:
+ self.mask_token = None
+
+ if self.use_absolute_embeddings:
+ self.position_embeddings = self.add_weight(
+ (1, self.num_patches + 1, self.embed_dim), initializer="zeros", name="positional_embeddings"
+ )
+ else:
+ self.position_embeddings = None
+ super().build(input_shape)
+
+ def call(
+ self, pixel_values: tf.Tensor, bool_masked_pos: bool = None, training: bool = False
+ ) -> Tuple[tf.Tensor, Tuple[int, int]]:
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training)
+ embeddings = self.norm(embeddings, training=training)
+ batch_size, seq_len, _ = shape_list(embeddings)
+
+ if bool_masked_pos is not None:
+ mask_tokens = tf.repeat(self.mask_token, batch_size, 0)
+ mask_tokens = tf.repeat(mask_tokens, seq_len, 1)
+ # replace the masked visual tokens by mask_tokens
+ mask = tf.expand_dims(bool_masked_pos, -1)
+ mask = tf.cast(mask, mask_tokens.dtype)
+
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ if self.position_embeddings is not None:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings, training=training)
+
+ return embeddings, output_dimensions
+
+
+class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ image_size = to_2tuple(image_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+
+ self.projection = tf.keras.layers.Conv2D(
+ filters=embed_dim, kernel_size=self.patch_size, strides=self.patch_size, padding="valid", name="projection"
+ )
+
+ def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor:
+ if width % self.patch_size[1] != 0:
+ pad_values = ((0, 0), (0, 0), (0, 0), (0, self.patch_size[1] - width % self.patch_size[1]))
+ pixel_values = tf.pad(pixel_values, pad_values)
+ if height % self.patch_size[0] != 0:
+ pad_values = ((0, 0), (0, 0), (0, self.patch_size[0] - height % self.patch_size[0]), (0, 0))
+ pixel_values = tf.pad(pixel_values, pad_values)
+ return pixel_values
+
+ def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]:
+ _, _, height, width = shape_list(pixel_values)
+ # pad the input to be divisible by self.patch_size, if needed
+ pixel_values = self.maybe_pad(pixel_values, height, width)
+
+ # B,C,H,W -> B,H,W,C
+ pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1))
+
+ embeddings = self.projection(pixel_values, training=training)
+
+ # B,H,W,C -> B,C,H,W
+ embeddings = tf.transpose(embeddings, (0, 3, 1, 2))
+
+ batch_size, channels, height, width = shape_list(embeddings)
+ output_dimensions = (height, width)
+
+ embeddings = tf.reshape(embeddings, (batch_size, channels, -1))
+ embeddings = tf.transpose(embeddings, (0, 2, 1))
+ return embeddings, output_dimensions
+
+
+class TFSwinPatchMerging(tf.keras.layers.Layer):
+ """
+ Patch Merging Layer.
+
+ Args:
+ input_resolution (`Tuple[int]`):
+ Resolution of input feature.
+ dim (`int`):
+ Number of input channels.
+ norm_layer (`tf.keras.layer.Layer`, *optional*, defaults to `tf.keras.layers.LayerNormalization`):
+ Normalization layer class.
+ """
+
+ def __init__(
+ self, input_resolution: Tuple[int, int], dim: int, norm_layer: Optional[Callable] = None, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = tf.keras.layers.Dense(2 * dim, use_bias=False, name="reduction")
+ if norm_layer is None:
+ # Use same default epsilon as PyTorch
+ self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="norm")
+ else:
+ self.norm = norm_layer(name="norm")
+
+ def maybe_pad(self, input_feature: tf.Tensor, height: int, width: int) -> tf.Tensor:
+ should_pad = (height % 2 == 1) or (width % 2 == 1)
+ if should_pad:
+ pad_values = ((0, 0), (0, height % 2), (0, width % 2), (0, 0))
+ input_feature = tf.pad(input_feature, pad_values)
+
+ return input_feature
+
+ def call(self, input_feature: tf.Tensor, input_dimensions: Tuple[int, int], training: bool = False) -> tf.Tensor:
+ height, width = input_dimensions
+ # `dim` is height * width
+ batch_size, _, num_channels = shape_list(input_feature)
+
+ input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels))
+ # pad input to be disible by width and height, if needed
+ input_feature = self.maybe_pad(input_feature, height, width)
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_0 = input_feature[:, 0::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_1 = input_feature[:, 1::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_2 = input_feature[:, 0::2, 1::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_3 = input_feature[:, 1::2, 1::2, :]
+ # batch_size height/2 width/2 4*num_channels
+ input_feature = tf.concat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+ input_feature = tf.reshape(
+ input_feature, (batch_size, -1, 4 * num_channels)
+ ) # batch_size height/2*width/2 4*C
+
+ input_feature = self.norm(input_feature, training=training)
+ input_feature = self.reduction(input_feature, training=training)
+
+ return input_feature
+
+
+class TFSwinDropPath(tf.keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: float = None, scale_by_keep: bool = True, **kwargs) -> None:
+ super(TFSwinDropPath, self).__init__(**kwargs)
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor:
+ return drop_path(input, self.drop_prob, training, self.scale_by_keep)
+
+
+class TFSwinSelfAttention(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ if dim % num_heads != 0:
+ raise ValueError(
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+ )
+
+ self.num_attention_heads = num_heads
+ self.attention_head_size = int(dim / num_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.window_size = to_2tuple(config.window_size)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = tf.range(self.window_size[0])
+ coords_w = tf.range(self.window_size[1])
+ coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
+ coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ relative_coords = tf.transpose(relative_coords, (1, 2, 0))
+
+ stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
+ stack_0 += self.window_size[0] - 1
+ stack_0 *= 2 * self.window_size[1] - 1
+ stack_1 += self.window_size[1] - 1
+ relative_coords = tf.stack([stack_0, stack_1], axis=2)
+ self.relative_position_index = tf.reduce_sum(relative_coords, axis=-1)
+
+ self.query = tf.keras.layers.Dense(
+ self.all_head_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ use_bias=config.qkv_bias,
+ name="query",
+ )
+ self.key = tf.keras.layers.Dense(
+ self.all_head_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ use_bias=config.qkv_bias,
+ name="key",
+ )
+ self.value = tf.keras.layers.Dense(
+ self.all_head_size,
+ kernel_initializer=get_initializer(config.initializer_range),
+ use_bias=config.qkv_bias,
+ name="value",
+ )
+
+ self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
+
+ def build(self, input_shape: tf.TensorShape) -> None:
+ self.relative_position_bias_table = self.add_weight(
+ shape=(((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)), self.num_attention_heads),
+ initializer="zeros",
+ name="relative_position_bias_table",
+ )
+ super().build(input_shape)
+
+ def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
+ new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
+ x = tf.reshape(x, new_x_shape)
+ return tf.transpose(x, (0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> Tuple[tf.Tensor, ...]:
+ batch_size, dim, _ = shape_list(hidden_states)
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, (0, 1, 3, 2)))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ relative_position_bias = tf.gather(
+ self.relative_position_bias_table, tf.reshape(self.relative_position_index, (-1,))
+ )
+ relative_position_bias = tf.reshape(
+ relative_position_bias,
+ (self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1),
+ )
+
+ relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1))
+ attention_scores = attention_scores + tf.expand_dims(relative_position_bias, 0)
+
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in SwinModel forward() function)
+ mask_shape = shape_list(attention_mask)[0]
+ attention_scores = tf.reshape(
+ attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim)
+ )
+ attention_mask = tf.expand_dims(attention_mask, 1)
+ attention_mask = tf.expand_dims(attention_mask, 0)
+ attention_scores = attention_scores + attention_mask
+ attention_scores = tf.reshape(attention_scores, (-1, self.num_attention_heads, dim, dim))
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = tf.nn.softmax(attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs, training=training)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = tf.matmul(attention_probs, value_layer)
+ context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
+ new_context_layer_shape = shape_list(context_layer)[:-2] + [
+ self.all_head_size,
+ ]
+ context_layer = tf.reshape(context_layer, new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class TFSwinSelfOutput(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = tf.keras.layers.Dense(dim, name="dense")
+ self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob, name="dropout")
+
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ return hidden_states
+
+
+class TFSwinAttention(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.self = TFSwinSelfAttention(config, dim, num_heads, name="self")
+ self.self_output = TFSwinSelfOutput(config, dim, name="output")
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ """
+ Prunes heads of the model. See base class PreTrainedModel heads: dict of {layer_num: list of heads to prune in
+ this layer}
+ """
+ raise NotImplementedError
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> tf.Tensor:
+ self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions, training=training)
+ attention_output = self.self_output(self_outputs[0], hidden_states, training=training)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class TFSwinIntermediate(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = tf.keras.layers.Dense(int(config.mlp_ratio * dim), name="dense")
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class TFSwinOutput(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, dim: int, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.dense = tf.keras.layers.Dense(dim, name="dense")
+ self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob, "dropout")
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ return hidden_states
+
+
+class TFSwinLayer(tf.keras.layers.Layer):
+ def __init__(
+ self, config, dim, input_resolution: Tuple[int, int], num_heads: int, shift_size: int = 0, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.shift_size = shift_size
+ self.window_size = config.window_size
+ self.input_resolution = input_resolution
+ self.set_shift_and_window_size(input_resolution)
+
+ self.layernorm_before = tf.keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="layernorm_before"
+ )
+ self.attention = TFSwinAttention(config, dim, num_heads, name="attention")
+ self.drop_path = (
+ TFSwinDropPath(config.drop_path_rate, name="drop_path")
+ if config.drop_path_rate > 0.0
+ else tf.identity(name="drop_path")
+ )
+ self.layernorm_after = tf.keras.layers.LayerNormalization(
+ epsilon=config.layer_norm_eps, name="layernorm_after"
+ )
+ self.intermediate = TFSwinIntermediate(config, dim, name="intermediate")
+ self.swin_output = TFSwinOutput(config, dim, name="output")
+
+ def set_shift_and_window_size(self, input_resolution: Tuple[int, int]) -> None:
+ if min(input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(input_resolution)
+
+ def get_attn_mask(self, height: int, width: int) -> Optional[tf.Tensor]:
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ img_mask = tf.zeros((height, width))
+ height_slices = (
+ (0, -self.window_size),
+ (-self.window_size, -self.shift_size),
+ (-self.shift_size, -1),
+ )
+ width_slices = (
+ (0, -self.window_size),
+ (-self.window_size, -self.shift_size),
+ (-self.shift_size, -1),
+ )
+
+ count = 0
+ for height_slice in height_slices:
+ for width_slice in width_slices:
+ indices = [
+ [i, j]
+ for i in range(height_slice[0] % height, height_slice[1] % height + 1)
+ for j in range(width_slice[0] % width, width_slice[1] % width + 1)
+ ]
+ if indices:
+ updates = tf.ones((len(indices),), dtype=img_mask.dtype) * count
+ img_mask = tf.tensor_scatter_nd_update(img_mask, indices, updates)
+ count += 1
+
+ img_mask = tf.expand_dims(img_mask, -1)
+ img_mask = tf.expand_dims(img_mask, 0)
+
+ mask_windows = window_partition(img_mask, self.window_size)
+ mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size))
+ attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
+ attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
+ attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
+ else:
+ attn_mask = None
+ return attn_mask
+
+ def maybe_pad(self, hidden_states: tf.Tensor, height: int, width: int) -> Tuple[tf.Tensor, tf.Tensor]:
+ pad_right = (self.window_size - width % self.window_size) % self.window_size
+ pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+ pad_values = tf.constant([[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]])
+ hidden_states = tf.pad(hidden_states, pad_values)
+ pad_values = tf.reshape(pad_values, (-1,))
+ return hidden_states, pad_values
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ training: bool = False,
+ ) -> tf.Tensor:
+ self.set_shift_and_window_size(input_dimensions)
+ height, width = input_dimensions
+ batch_size, _, channels = shape_list(hidden_states)
+ shortcut = hidden_states
+
+ hidden_states = self.layernorm_before(hidden_states, training=training)
+ hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels))
+ # pad hidden_states to multiples of window size
+ hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+ _, height_pad, width_pad, _ = shape_list(hidden_states)
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_hidden_states = tf.roll(hidden_states, shift=(-self.shift_size, -self.shift_size), axis=(1, 2))
+ else:
+ shifted_hidden_states = hidden_states
+
+ # partition windows
+ hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
+ hidden_states_windows = tf.reshape(hidden_states_windows, (-1, self.window_size * self.window_size, channels))
+ attn_mask = self.get_attn_mask(height_pad, width_pad)
+
+ attention_outputs = self.attention(
+ hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions, training=training
+ )
+
+ attention_output = attention_outputs[0]
+
+ attention_windows = tf.reshape(attention_output, (-1, self.window_size, self.window_size, channels))
+ shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ attention_windows = tf.roll(shifted_windows, shift=(self.shift_size, self.shift_size), axis=(1, 2))
+ else:
+ attention_windows = shifted_windows
+
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
+ if was_padded:
+ attention_windows = attention_windows[:, :height, :width, :]
+
+ attention_windows = tf.reshape(attention_windows, (batch_size, height * width, channels))
+
+ hidden_states = shortcut + self.drop_path(attention_windows, training=training)
+
+ layer_output = self.layernorm_after(hidden_states, training=training)
+ layer_output = self.intermediate(layer_output)
+ layer_output = hidden_states + self.swin_output(layer_output, training=training)
+
+ layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+ return layer_outputs
+
+
+class TFSwinStage(tf.keras.layers.Layer):
+ def __init__(
+ self,
+ config: SwinConfig,
+ dim: int,
+ input_resolution: Tuple[int, int],
+ depth: int,
+ num_heads: int,
+ drop_path: List[float],
+ downsample: Optional[Callable],
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+ self.dim = dim
+ self.blocks = [
+ TFSwinLayer(
+ config=config,
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ shift_size=0 if (i % 2 == 0) else config.window_size // 2,
+ name=f"blocks.{i}",
+ )
+ for i in range(depth)
+ ]
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ input_resolution,
+ dim=dim,
+ norm_layer=partial(tf.keras.layers.LayerNormalization, epsilon=1e-5),
+ name="downsample",
+ )
+ else:
+ self.downsample = None
+
+ self.pointing = False
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ training: bool = False,
+ ) -> Tuple[tf.Tensor, ...]:
+ height, width = input_dimensions
+ for i, layer_module in enumerate(self.blocks):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if self.downsample is not None:
+ height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+ output_dimensions = (height, width, height_downsampled, width_downsampled)
+ hidden_states = self.downsample(layer_outputs[0], input_dimensions, training=training)
+ else:
+ output_dimensions = (height, width, height, width)
+
+ stage_outputs = (hidden_states, output_dimensions)
+
+ if output_attentions:
+ stage_outputs += layer_outputs[1:]
+ return stage_outputs
+
+
+class TFSwinEncoder(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, grid_size: Tuple[int, int], **kwargs):
+ super().__init__(**kwargs)
+ self.num_layers = len(config.depths)
+ self.config = config
+ dpr = list((tf.linspace(0, 1, sum(config.depths)) * config.drop_path_rate).numpy())
+ self.layers = [
+ TFSwinStage(
+ config=config,
+ dim=int(config.embed_dim * 2**i_layer),
+ input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+ depth=config.depths[i_layer],
+ num_heads=config.num_heads[i_layer],
+ drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+ downsample=TFSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
+ name=f"layers.{i_layer}",
+ )
+ for i_layer in range(self.num_layers)
+ ]
+
+ self.gradient_checkpointing = False
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ training: bool = False,
+ ) -> Union[Tuple[tf.Tensor, ...], TFSwinEncoderOutput]:
+ all_input_dimensions = ()
+ all_hidden_states = () if output_hidden_states else None
+ all_reshaped_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = shape_list(hidden_states)
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
+ reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ for i, layer_module in enumerate(self.layers):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states, input_dimensions, layer_head_mask, output_attentions, training=training
+ )
+
+ hidden_states = layer_outputs[0]
+ output_dimensions = layer_outputs[1]
+
+ input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+ all_input_dimensions += (input_dimensions,)
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = shape_list(hidden_states)
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
+ reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ if output_attentions:
+ all_self_attentions += layer_outputs[2:]
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return TFSwinEncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ reshaped_hidden_states=all_reshaped_hidden_states,
+ )
+
+
+class TFSwinPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = SwinConfig
+ base_model_prefix = "swin"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _set_gradient_checkpointing(self, module, value=False) -> None:
+ if isinstance(module, TFSwinEncoder):
+ module.gradient_checkpointing = value
+
+ @property
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
+ """
+ Dummy inputs to build the network. Returns:
+ `Dict[str, tf.Tensor]`: The dummy inputs.
+ """
+ VISION_DUMMY_INPUTS = tf.random.uniform(
+ shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size),
+ dtype=tf.float32,
+ )
+ return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
+
+
+SWIN_START_DOCSTRING = r"""
+ This model is a Tensorflow
+ [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
+ regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`SwinConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWIN_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+ head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+def normalize_data_format(value: str) -> str:
+ """
+ From tensorflow addons
+ https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/utils/keras_utils.py#L71
+ """
+ if value is None:
+ value = tf.keras.backend.image_data_format()
+ data_format = value.lower()
+ if data_format not in {"channels_first", "channels_last"}:
+ raise ValueError(
+ 'The `data_format` argument must be one of "channels_first", "channels_last". Received: ' + str(value)
+ )
+ return data_format
+
+
+class AdaptiveAveragePooling1D(tf.keras.layers.Layer):
+ """
+ Args:
+ Average 1D Pooling with adaptive kernel size.
+ output_size: An integer or tuple/list of a single integer, specifying pooled_features.
+ The new size of output channels.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape `(batch, steps, channels)` while `channels_first` corresponds
+ to inputs with shape `(batch, channels, steps)`.
+ Input shape:
+ - If `data_format='channels_last'`: 3D tensor with shape `(batch, steps, channels)`.
+ - If `data_format='channels_first'`: 3D tensor with shape `(batch, channels, steps)`.
+ Output shape:
+ - If `data_format='channels_last'`: 3D tensor with shape `(batch_size, pooled_steps, channels)`.
+ - If `data_format='channels_first'`: 3D tensor with shape `(batch_size, channels, pooled_steps)`.
+
+ Adapted from [tensorflow-addon's adaptive pooling.py](
+ https://github.com/tensorflow/addons/blob/8cec33fcaaf1cf90aec7bdd55a0fcdbb251ce5c2/tensorflow_addons/layers/adaptive_pooling.py#L90-L120
+ )
+ """
+
+ def __init__(
+ self,
+ output_size: Union[int, Iterable[int]],
+ reduce_function: Callable = tf.reduce_mean,
+ data_format: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ self.data_format = normalize_data_format(data_format)
+ self.reduce_function = reduce_function
+ self.output_size = (output_size,) if isinstance(output_size, int) else tuple(output_size)
+ super().__init__(**kwargs)
+
+ def call(self, inputs: tf.Tensor, *args) -> None:
+ bins = self.output_size[0]
+ if self.data_format == "channels_last":
+ splits = tf.split(inputs, bins, axis=1)
+ splits = tf.stack(splits, axis=1)
+ out_vect = self.reduce_function(splits, axis=2)
+ else:
+ splits = tf.split(inputs, bins, axis=2)
+ splits = tf.stack(splits, axis=2)
+ out_vect = self.reduce_function(splits, axis=3)
+ return out_vect
+
+ def compute_output_shape(self, input_shape: Iterable[int]) -> tf.TensorShape:
+ input_shape = tf.TensorShape(input_shape).as_list()
+ if self.data_format == "channels_last":
+ shape = tf.TensorShape([input_shape[0], self.output_size[0], input_shape[2]])
+ else:
+ shape = tf.TensorShape([input_shape[0], input_shape[1], self.output_size[0]])
+ return shape
+
+ def get_config(self) -> Dict[str, Any]:
+ config = {
+ "output_size": self.output_size,
+ "data_format": self.data_format,
+ }
+ base_config = super().get_config()
+ return {**base_config, **config}
+
+
+@keras_serializable
+class TFSwinMainLayer(tf.keras.layers.Layer):
+ config_class = SwinConfig
+
+ def __init__(
+ self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+ self.num_layers = len(config.depths)
+ self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.embeddings = TFSwinEmbeddings(config, use_mask_token=use_mask_token, name="embeddings")
+ self.encoder = TFSwinEncoder(config, self.embeddings.patch_grid, name="encoder")
+
+ self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ self.pooler = AdaptiveAveragePooling1D(output_size=(1,)) if add_pooling_layer else None
+
+ def get_input_embeddings(self) -> TFSwinPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List]):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_head_mask(self, head_mask: Optional[Any]) -> List:
+ if head_mask is not None:
+ raise NotImplementedError
+ return [None] * len(self.config.depths)
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ bool_masked_pos: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask)
+ embedding_output, input_dimensions = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, training=training
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ input_dimensions,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output, training=training)
+
+ pooled_output = None
+ if self.pooler is not None:
+ batch_size, _, num_features = shape_list(sequence_output)
+ pooled_output = self.pooler(sequence_output)
+ pooled_output = tf.reshape(pooled_output, (batch_size, num_features))
+
+ if not return_dict:
+ output = (sequence_output, pooled_output) + encoder_outputs[1:]
+ return output
+
+ return TFSwinModelOutput(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+ )
+
+
+@add_start_docstrings(
+ "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
+ SWIN_START_DOCSTRING,
+)
+class TFSwinModel(TFSwinPreTrainedModel):
+ def __init__(
+ self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+ ) -> None:
+ super().__init__(config, **kwargs)
+ self.config = config
+ self.swin = TFSwinMainLayer(config, name="swin")
+
+ @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFSwinModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ bool_masked_pos: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ swin_outputs = self.swin(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return swin_outputs
+
+
+class PixelShuffle(tf.keras.layers.Layer):
+ """TF layer implementation of torch.nn.PixelShuffle"""
+
+ def __init__(
+ self,
+ upscale_factor: int,
+ data_format: str = "NHWC",
+ trainable: bool = True,
+ name: str = None,
+ dtype=None,
+ dynamic: bool = False,
+ **kwargs
+ ) -> None:
+ super().__init__(trainable, name, dtype, dynamic, **kwargs)
+ if upscale_factor < 2:
+ raise ValueError("upscale_factor must be an integer value >= 2")
+ self.upscale_factor = upscale_factor
+ self.data_format = data_format
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ return tf.nn.depth_to_space(x, block_size=self.upscale_factor, data_format=self.data_format)
+
+
+class TFSwinDecoder(tf.keras.layers.Layer):
+ def __init__(self, config: SwinConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.conv2d = tf.keras.layers.Conv2D(
+ filters=config.encoder_stride**2 * 3, kernel_size=1, strides=1, name="0"
+ )
+ self._block_size = config.encoder_stride
+ self.pixel_shuffle = PixelShuffle(self._block_size, name="1")
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ hidden_states = x
+ # B,C,H,W -> B,H,W,C
+ hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1))
+ hidden_states = self.conv2d(hidden_states)
+ batch_size, _, _, num_input_channels = shape_list(hidden_states)
+ block_size_squared = self._block_size**2
+ output_depth = int(num_input_channels / block_size_squared)
+ # When the number of output channels >= 2, PyTorch's PixelShuffle and
+ # TF's depth_to_space differ in their output as the order of channels selected for combining
+ # is a permutation of the other c.f.
+ # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1
+ permutation = tf.constant(
+ [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]]
+ )
+ hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1)
+ hidden_states = self.pixel_shuffle(hidden_states)
+ # B,H,W,C -> B,C,H,W
+ hidden_states = tf.transpose(hidden_states, (0, 3, 1, 2))
+ return hidden_states
+
+
+@add_start_docstrings(
+ "Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
+ " `__.",
+ SWIN_START_DOCSTRING,
+)
+class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
+ def __init__(self, config: SwinConfig):
+ super().__init__(config)
+
+ self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin")
+
+ self.decoder = TFSwinDecoder(config, name="decoder")
+
+ @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ bool_masked_pos: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple, TFSwinMaskedImageModelingOutput]:
+ r"""
+ bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoFeatureExtractor, TFSwinForMaskedImageModeling
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+ >>> model = TFSwinForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+ >>> pixel_values = feature_extractor(images=image, return_tensors="tf").pixel_values
+ >>> # create random boolean mask of shape (batch_size, num_patches)
+ >>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
+ >>> list(reconstructed_pixel_values.shape)
+ [1, 3, 224, 224]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.swin(
+ pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+ # Reshape to (batch_size, num_channels, height, width)
+ sequence_output = tf.transpose(sequence_output, (0, 2, 1))
+ batch_size, num_channels, sequence_length = shape_list(sequence_output)
+ height = width = int(sequence_length**0.5)
+ sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width))
+
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.decoder(sequence_output)
+
+ masked_im_loss = None
+ if bool_masked_pos is not None:
+ size = self.config.image_size // self.config.patch_size
+ bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size))
+ mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1)
+ mask = tf.repeat(mask, self.config.patch_size, 2)
+ mask = tf.expand_dims(mask, 1)
+ mask = tf.cast(mask, tf.float32)
+
+ reconstruction_loss = tf.keras.losses.mean_absolute_error(
+ # Swap axes as metric calculation reduces over the final dimension
+ tf.transpose(pixel_values, (1, 2, 3, 0)),
+ tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)),
+ )
+ reconstruction_loss = tf.expand_dims(reconstruction_loss, 0)
+ total_loss = tf.reduce_sum(reconstruction_loss * mask)
+ num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels
+ masked_im_loss = total_loss / num_masked_pixels
+
+ if not return_dict:
+ output = (reconstructed_pixel_values,) + outputs[2:]
+ return ((masked_im_loss,) + output) if masked_im_loss is not None else output
+
+ return TFSwinMaskedImageModelingOutput(
+ loss=masked_im_loss,
+ logits=reconstructed_pixel_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.
+ """,
+ SWIN_START_DOCSTRING,
+)
+class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: SwinConfig):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.swin = TFSwinMainLayer(config, name="swin")
+
+ # Classifier head
+ self.classifier = (
+ tf.keras.layers.Dense(config.num_labels, name="classifier")
+ if config.num_labels > 0
+ else tf.identity(name="classifier")
+ )
+
+ @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFSwinImageClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ head_mask: Optional[tf.Tensor] = None,
+ labels: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple[tf.Tensor, ...], TFSwinImageClassifierOutput]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.swin(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = outputs[1]
+
+ logits = self.classifier(pooled_output, training=training)
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSwinImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
+ )
diff --git a/src/transformers/models/t5/__init__.py b/src/transformers/models/t5/__init__.py
index 9ccb94932843df..178fb38678e7d6 100644
--- a/src/transformers/models/t5/__init__.py
+++ b/src/transformers/models/t5/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -28,17 +29,30 @@
)
-_import_structure = {
- "configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config", "T5OnnxConfig"],
-}
+_import_structure = {"configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config", "T5OnnxConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_t5"] = ["T5Tokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_t5_fast"] = ["T5TokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_t5"] = [
"T5_PRETRAINED_MODEL_ARCHIVE_LIST",
"T5EncoderModel",
@@ -48,7 +62,12 @@
"load_tf_weights_in_t5",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_t5"] = [
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFT5EncoderModel",
@@ -57,7 +76,12 @@
"TFT5PreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_t5"] = [
"FlaxT5ForConditionalGeneration",
"FlaxT5Model",
@@ -68,13 +92,28 @@
if TYPE_CHECKING:
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config, T5OnnxConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_t5 import T5Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_t5_fast import T5TokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5EncoderModel,
@@ -84,7 +123,12 @@
load_tf_weights_in_t5,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
@@ -93,7 +137,12 @@
TFT5PreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py
index b09539c86d705a..a2bd03dfd74cc8 100644
--- a/src/transformers/models/t5/configuration_t5.py
+++ b/src/transformers/models/t5/configuration_t5.py
@@ -116,6 +116,22 @@ def __init__(
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
+
+ act_info = self.feed_forward_proj.split("-")
+ self.dense_act_fn = act_info[-1]
+ self.is_gated_act = act_info[0] == "gated"
+
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+ raise ValueError(
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+ "'gated-gelu' or 'relu'"
+ )
+
+ # for backwards compatibility
+ if feed_forward_proj == "gated-gelu":
+ self.dense_act_fn = "gelu_new"
+
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
diff --git a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py
index a0020301682293..7d9a20f3b0b395 100755
--- a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py
@@ -49,8 +49,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained T5 model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py
index 263412578c23d8..23f7436eab80ef 100644
--- a/src/transformers/models/t5/modeling_flax_t5.py
+++ b/src/transformers/models/t5/modeling_flax_t5.py
@@ -87,7 +87,7 @@ def __call__(self, hidden_states):
return self.weight * hidden_states
-class FlaxT5DenseReluDense(nn.Module):
+class FlaxT5DenseActDense(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32
@@ -108,16 +108,17 @@ def setup(self):
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
+ self.act = ACT2FN[self.config.dense_act_fn]
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.wi(hidden_states)
- hidden_states = jax.nn.relu(hidden_states)
+ hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.wo(hidden_states)
return hidden_states
-class FlaxT5DenseGatedGeluDense(nn.Module):
+class FlaxT5DenseGatedActDense(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@@ -144,10 +145,10 @@ def setup(self):
dtype=self.dtype,
)
self.dropout = nn.Dropout(self.config.dropout_rate)
- self.gelu_act = ACT2FN["gelu_new"]
+ self.act = ACT2FN[self.config.dense_act_fn]
def __call__(self, hidden_states, deterministic):
- hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
+ hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
@@ -160,14 +161,10 @@ class FlaxT5LayerFF(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
- if self.config.feed_forward_proj == "relu":
- self.DenseReluDense = FlaxT5DenseReluDense(self.config, dtype=self.dtype)
- elif self.config.feed_forward_proj == "gated-gelu":
- self.DenseReluDense = FlaxT5DenseGatedGeluDense(self.config, dtype=self.dtype)
+ if self.config.is_gated_act:
+ self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype)
else:
- raise ValueError(
- f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
- )
+ self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype)
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
self.dropout = nn.Dropout(self.config.dropout_rate)
@@ -977,7 +974,8 @@ def __call__(
if decoder_input_ids is None:
raise ValueError(
- "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
+ "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
+ " here."
)
# prepare encoder inputs
@@ -1243,7 +1241,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
@add_start_docstrings(
- "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
+ "The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
T5_START_DOCSTRING,
)
class FlaxT5Module(nn.Module):
@@ -1344,7 +1342,7 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="np"
- >>> ).input_ids
+ ... ).input_ids
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
>>> # forward pass
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 6d06f910df400a..b974ad4b2003f1 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -276,33 +276,33 @@ def forward(self, hidden_states):
pass
-class T5DenseReluDense(nn.Module):
+class T5DenseActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
- self.relu_act = ACT2FN["relu"]
+ self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
hidden_states = self.wi(hidden_states)
- hidden_states = self.relu_act(hidden_states)
+ hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
-class T5DenseGatedGeluDense(nn.Module):
+class T5DenseGatedActDense(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
- self.gelu_act = ACT2FN["gelu_new"]
+ self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
- hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
+ hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
@@ -313,14 +313,10 @@ def forward(self, hidden_states):
class T5LayerFF(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
- if config.feed_forward_proj == "relu":
- self.DenseReluDense = T5DenseReluDense(config)
- elif config.feed_forward_proj == "gated-gelu":
- self.DenseReluDense = T5DenseGatedGeluDense(config)
+ if config.is_gated_act:
+ self.DenseReluDense = T5DenseGatedActDense(config)
else:
- raise ValueError(
- f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
- )
+ self.DenseReluDense = T5DenseActDense(config)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
@@ -408,26 +404,24 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
- relative_postion_if_large = max_exact + (
+ relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
- relative_postion_if_large = torch.min(
- relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
- relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
- def compute_bias(self, query_length, key_length):
+ def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
- context_position = torch.arange(
- query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
- )[:, None]
- memory_position = torch.arange(
- key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
- )[None, :]
+ if device is None:
+ device = self.relative_attention_bias.weight.device
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
@@ -522,7 +516,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
- position_bias = self.compute_bias(real_seq_length, key_length)
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
# if key and values are already calculated
# we want only the last query position bias
@@ -747,6 +741,7 @@ class T5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
+ _no_split_modules = ["T5Block"]
@property
def dummy_inputs(self):
@@ -768,7 +763,9 @@ def _init_weights(self, module):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
- elif isinstance(module, T5DenseReluDense):
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
+ elif isinstance(module, T5DenseActDense):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
@@ -778,7 +775,7 @@ def _init_weights(self, module):
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
- elif isinstance(module, T5DenseGatedGeluDense):
+ elif isinstance(module, T5DenseGatedActDense):
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
module.wi_0.bias.data.zero_()
@@ -809,9 +806,10 @@ def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
- assert (
- decoder_start_token_id is not None
- ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
+ " See T5 docs for more information"
+ )
# shift inputs to the right
if is_torch_fx_proxy(input_ids):
@@ -957,7 +955,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1268,11 +1266,11 @@ def custom_forward(*inputs):
)
class T5Model(T5PreTrainedModel):
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
- r"decoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_unexpected = [
- r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
def __init__(self, config: T5Config):
@@ -1375,7 +1373,7 @@ def forward(
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
>>> # forward pass
@@ -1457,12 +1455,12 @@ def forward(
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
_keys_to_ignore_on_load_missing = [
- r"encoder\.embed_tokens\.weight",
- r"decoder\.embed_tokens\.weight",
- r"lm_head\.weight",
+ r"encoder.embed_tokens.weight",
+ r"decoder.embed_tokens.weight",
+ r"lm_head.weight",
]
_keys_to_ignore_on_load_unexpected = [
- r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
def __init__(self, config: T5Config):
@@ -1583,7 +1581,7 @@ def forward(
>>> # inference
>>> input_ids = tokenizer(
... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> # studies have shown that owning a dog is good for you.
@@ -1751,7 +1749,7 @@ def _reorder_cache(self, past, beam_idx):
)
class T5EncoderModel(T5PreTrainedModel):
authorized_missing_keys = [
- r"encoder\.embed_tokens\.weight",
+ r"encoder.embed_tokens.weight",
]
def __init__(self, config: T5Config):
@@ -1831,7 +1829,7 @@ def forward(
>>> model = T5EncoderModel.from_pretrained("t5-small")
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs.last_hidden_state
```"""
diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py
index 2e48174a9048b7..77a65557daaa90 100644
--- a/src/transformers/models/t5/modeling_tf_t5.py
+++ b/src/transformers/models/t5/modeling_tf_t5.py
@@ -93,7 +93,7 @@ def call(self, hidden_states):
return self.weight * hidden_states
-class TFT5DenseReluDense(tf.keras.layers.Layer):
+class TFT5DenseActDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
wi_initializer = tf.keras.initializers.RandomNormal(
@@ -109,7 +109,7 @@ def __init__(self, config, **kwargs):
config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
- self.act = tf.keras.activations.relu
+ self.act = get_tf_activation(config.dense_act_fn)
def call(self, hidden_states, training=False):
hidden_states = self.wi(hidden_states)
@@ -119,7 +119,7 @@ def call(self, hidden_states, training=False):
return hidden_states
-class TFT5GatedGeluDense(tf.keras.layers.Layer):
+class TFT5DenseGatedActDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
wi_initializer = tf.keras.initializers.RandomNormal(
@@ -138,7 +138,7 @@ def __init__(self, config, **kwargs):
config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
) # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
- self.act = get_tf_activation("gelu_new")
+ self.act = get_tf_activation(config.dense_act_fn)
def call(self, hidden_states, training=False):
hidden_gelu = self.act(self.wi_0(hidden_states))
@@ -152,14 +152,11 @@ def call(self, hidden_states, training=False):
class TFT5LayerFF(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
- if config.feed_forward_proj == "relu":
- self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
- elif config.feed_forward_proj == "gated-gelu":
- self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense")
+ if config.is_gated_act:
+ self.DenseReluDense = TFT5DenseGatedActDense(config, name="DenseReluDense")
else:
- raise ValueError(
- f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
- )
+ self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")
+
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
@@ -406,7 +403,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.n_heads],
- message=f"Head mask for a single layer should be of size {(self.n_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.n_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
@@ -899,9 +899,10 @@ def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
- assert (
- decoder_start_token_id is not None
- ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
+ assert decoder_start_token_id is not None, (
+ "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the"
+ " pad_token_id. See T5 docs for more information"
+ )
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
@@ -1102,13 +1103,15 @@ def _shift_right(self, input_ids):
@add_start_docstrings(
- "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
+ "The bare T5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
T5_START_DOCSTRING,
)
class TFT5Model(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
- self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+ self.shared = TFSharedEmbeddings(
+ config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+ )
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@@ -1165,7 +1168,7 @@ def call(
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="tf"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
>>> # forward pass
@@ -1255,8 +1258,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model_dim = config.d_model
-
- self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+ self.shared = TFSharedEmbeddings(
+ config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+ )
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@@ -1353,7 +1357,7 @@ def call(
>>> # inference
>>> inputs = tokenizer(
... "summarize: studies have shown that owning a dog is good for you", return_tensors="tf"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> outputs = model.generate(inputs)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> # studies have shown that owning a dog is good for you
@@ -1590,13 +1594,15 @@ def _reorder_cache(self, past, beam_idx):
@add_start_docstrings(
- "The bare T5 Model transformer outputting encoder's raw hidden-states" "without any specific head on top.",
+ "The bare T5 Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.",
T5_START_DOCSTRING,
)
class TFT5EncoderModel(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
- self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+ self.shared = TFSharedEmbeddings(
+ config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+ )
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@@ -1642,7 +1648,7 @@ def call(
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="tf"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> outputs = model(input_ids)
```"""
diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py
index a356aa70c1877c..2dbc788374dcf8 100644
--- a/src/transformers/models/t5/tokenization_t5.py
+++ b/src/transformers/models/t5/tokenization_t5.py
@@ -41,6 +41,8 @@
}
}
+
+# TODO(PVP) - this should be removed in Transformers v5
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"t5-small": 512,
"t5-base": 512,
@@ -129,8 +131,9 @@ def __init__(
extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
if extra_tokens != extra_ids:
raise ValueError(
- f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
- "In this case the additional_special_tokens must include the extra_ids tokens"
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+ " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
+ " tokens"
)
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
@@ -151,6 +154,28 @@ def __init__(
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
+ @staticmethod
+ def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+ if pretrained_model_name_or_path in T5Tokenizer.max_model_input_sizes:
+ deprecated_max_model_length = T5Tokenizer.max_model_input_sizes[pretrained_model_name_or_path]
+ if init_max_model_length is not None and init_max_model_length != max_model_length:
+ return init_max_model_length
+ elif init_max_model_length is None:
+ warnings.warn(
+ "This tokenizer was incorrectly instantiated with a model max length of"
+ f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
+ " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
+ " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
+ f" {pretrained_model_name_or_path} automatically truncating your input to"
+ f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
+ f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
+ " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
+ " instantiate this tokenizer with `model_max_length` set to your preferred value.",
+ FutureWarning,
+ )
+
+ return max_model_length
+
@property
def vocab_size(self):
return self.sp_model.get_piece_size() + self._extra_ids
@@ -192,7 +217,8 @@ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
"""Do not add eos again if user already added it."""
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
warnings.warn(
- f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+ " eos tokens being added."
)
return token_ids
else:
diff --git a/src/transformers/models/t5/tokenization_t5_fast.py b/src/transformers/models/t5/tokenization_t5_fast.py
index 46128682365f91..41ad306b74e6a8 100644
--- a/src/transformers/models/t5/tokenization_t5_fast.py
+++ b/src/transformers/models/t5/tokenization_t5_fast.py
@@ -16,6 +16,7 @@
import os
+import warnings
from shutil import copyfile
from typing import List, Optional, Tuple
@@ -50,6 +51,8 @@
},
}
+
+# TODO(PVP) - this should be removed in Transformers v5
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"t5-small": 512,
"t5-base": 512,
@@ -123,8 +126,9 @@ def __init__(
extra_tokens = len(set(filter(lambda x: bool("extra_id_" in str(x)), additional_special_tokens)))
if extra_tokens != extra_ids:
raise ValueError(
- f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
- "In this case the additional_special_tokens must include the extra_ids tokens"
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+ " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
+ " tokens"
)
super().__init__(
@@ -142,6 +146,28 @@ def __init__(
self.can_save_slow_tokenizer = False if not self.vocab_file else True
self._extra_ids = extra_ids
+ @staticmethod
+ def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+ if pretrained_model_name_or_path in T5TokenizerFast.max_model_input_sizes:
+ deprecated_max_model_length = T5TokenizerFast.max_model_input_sizes[pretrained_model_name_or_path]
+ if init_max_model_length is not None and init_max_model_length != max_model_length:
+ return init_max_model_length
+ elif init_max_model_length is None:
+ warnings.warn(
+ "This tokenizer was incorrectly instantiated with a model max length of"
+ f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
+ " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
+ " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
+ f" {pretrained_model_name_or_path} automatically truncating your input to"
+ f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
+ f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
+ " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
+ " instantiate this tokenizer with `model_max_length` set to your preferred value.",
+ FutureWarning,
+ )
+
+ return max_model_length
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
diff --git a/src/transformers/models/tapas/__init__.py b/src/transformers/models/tapas/__init__.py
index 4d3c72b85b32d1..bbfb09ea0fee68 100644
--- a/src/transformers/models/tapas/__init__.py
+++ b/src/transformers/models/tapas/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_tapas": ["TapasTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tapas"] = [
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TapasForMaskedLM",
@@ -36,7 +41,12 @@
"TapasPreTrainedModel",
"load_tf_weights_in_tapas",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_tapas"] = [
"TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFTapasForMaskedLM",
@@ -51,7 +61,12 @@
from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from .tokenization_tapas import TapasTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
@@ -62,7 +77,12 @@
load_tf_weights_in_tapas,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_tapas import (
TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TFTapasForMaskedLM,
diff --git a/src/transformers/models/tapas/configuration_tapas.py b/src/transformers/models/tapas/configuration_tapas.py
index 58fb0c66b73a39..71fd5715ef57fb 100644
--- a/src/transformers/models/tapas/configuration_tapas.py
+++ b/src/transformers/models/tapas/configuration_tapas.py
@@ -27,10 +27,18 @@
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "google/tapas-base-finetuned-sqa": "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/config.json",
- "google/tapas-base-finetuned-wtq": "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/config.json",
- "google/tapas-base-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/config.json",
- "google/tapas-base-finetuned-tabfact": "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/config.json",
+ "google/tapas-base-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/config.json"
+ ),
+ "google/tapas-base-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/config.json"
+ ),
+ "google/tapas-base-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/config.json"
+ ),
+ "google/tapas-base-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py
index 88edacacfddcb0..2772a7f126ef9a 100644
--- a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py
@@ -120,8 +120,10 @@ def convert_tf_checkpoint_to_pytorch(
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained TAPAS model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained TAPAS model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py
index e34c1abb57ec33..0b65e84ca7acda 100644
--- a/src/transformers/models/tapas/modeling_tapas.py
+++ b/src/transformers/models/tapas/modeling_tapas.py
@@ -582,7 +582,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -954,7 +955,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
@@ -1068,7 +1069,7 @@ def forward(
... )
>>> labels = tokenizer(
... table=table, queries="How many movies has George Clooney played in?", return_tensors="pt"
- >>> )["input_ids"]
+ ... )["input_ids"]
>>> outputs = model(**inputs, labels=labels)
>>> logits = outputs.logits
@@ -1430,7 +1431,8 @@ def forward(
per_example_additional_loss *= large_answer_loss_mask
else:
raise ValueError(
- "You have to specify numeric values and numeric values scale in order to calculate the regression loss"
+ "You have to specify numeric values and numeric values scale in order to calculate the"
+ " regression loss"
)
total_loss += torch.mean(per_example_additional_loss)
diff --git a/src/transformers/models/tapas/modeling_tf_tapas.py b/src/transformers/models/tapas/modeling_tf_tapas.py
index e91baaab8edb23..1875cc8009075c 100644
--- a/src/transformers/models/tapas/modeling_tf_tapas.py
+++ b/src/transformers/models/tapas/modeling_tf_tapas.py
@@ -519,8 +519,8 @@ def call(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
- "by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -1095,7 +1095,7 @@ def call(
... )
>>> labels = tokenizer(
... table=table, queries="How many movies has George Clooney played in?", return_tensors="tf"
- >>> )["input_ids"]
+ ... )["input_ids"]
>>> outputs = model(**inputs, labels=labels)
>>> logits = outputs.logits
@@ -1533,7 +1533,8 @@ def call(
per_example_additional_loss *= large_answer_loss_mask
else:
raise ValueError(
- "You have to specify numeric values and numeric values scale in order to calculate the regression loss"
+ "You have to specify numeric values and numeric values scale in order to calculate the"
+ " regression loss"
)
total_loss += tf.reduce_mean(per_example_additional_loss)
@@ -1723,10 +1724,13 @@ def __init__(self, outer_index, inner_index):
inner_index: IndexMap, must have the same shape as `outer_index`.
"""
if outer_index.batch_dims != inner_index.batch_dims:
- raise ValueError("outer_index.batch_dims and inner_index.batch_dims " "must be the same.")
+ raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.")
super(ProductIndexMap, self).__init__(
- indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
+ indices=(
+ inner_index.indices
+ + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype)
+ ),
num_segments=inner_index.num_segments * outer_index.num_segments,
batch_dims=inner_index.batch_dims,
)
@@ -1785,7 +1789,7 @@ def flatten(index, name="segmented_flatten"):
for _ in range(index.batch_dims, index.indices.shape.rank):
offset = tf.expand_dims(offset, -1)
- indices = offset + index.indices
+ indices = tf.cast(offset, index.indices.dtype) + index.indices
return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)
diff --git a/src/transformers/models/tapas/tokenization_tapas.py b/src/transformers/models/tapas/tokenization_tapas.py
index 27481c35fb1436..ddb855642f4338 100644
--- a/src/transformers/models/tapas/tokenization_tapas.py
+++ b/src/transformers/models/tapas/tokenization_tapas.py
@@ -50,35 +50,83 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
# large models
- "google/tapas-large-finetuned-sqa": "https://huggingface.co/google/tapas-large-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-large-finetuned-wtq": "https://huggingface.co/google/tapas-large-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-large-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-large-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-large-finetuned-tabfact": "https://huggingface.co/google/tapas-large-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-large-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-large-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-large-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-large-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-large-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-large-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-large-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-large-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# base models
- "google/tapas-base-finetuned-sqa": "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-base-finetuned-wtq": "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-base-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-base-finetuned-tabfact": "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-base-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-base-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-base-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-base-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# medium models
- "google/tapas-medium-finetuned-sqa": "https://huggingface.co/google/tapas-medium-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-medium-finetuned-wtq": "https://huggingface.co/google/tapas-medium-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-medium-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-medium-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-medium-finetuned-tabfact": "https://huggingface.co/google/tapas-medium-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-medium-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-medium-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-medium-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-medium-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-medium-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-medium-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-medium-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-medium-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# small models
- "google/tapas-small-finetuned-sqa": "https://huggingface.co/google/tapas-small-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-small-finetuned-wtq": "https://huggingface.co/google/tapas-small-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-small-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-small-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-small-finetuned-tabfact": "https://huggingface.co/google/tapas-small-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-small-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-small-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-small-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-small-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-small-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-small-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-small-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-small-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# tiny models
- "google/tapas-tiny-finetuned-sqa": "https://huggingface.co/google/tapas-tiny-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-tiny-finetuned-wtq": "https://huggingface.co/google/tapas-tiny-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-tiny-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-tiny-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-tiny-finetuned-tabfact": "https://huggingface.co/google/tapas-tiny-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-tiny-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-tiny-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-tiny-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-tiny-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-tiny-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-tiny-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-tiny-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-tiny-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
# mini models
- "google/tapas-mini-finetuned-sqa": "https://huggingface.co/google/tapas-mini-finetuned-sqa/resolve/main/vocab.txt",
- "google/tapas-mini-finetuned-wtq": "https://huggingface.co/google/tapas-mini-finetuned-wtq/resolve/main/vocab.txt",
- "google/tapas-mini-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-mini-finetuned-wikisql-supervised/resolve/main/vocab.txt",
- "google/tapas-mini-finetuned-tabfact": "https://huggingface.co/google/tapas-mini-finetuned-tabfact/resolve/main/vocab.txt",
+ "google/tapas-mini-finetuned-sqa": (
+ "https://huggingface.co/google/tapas-mini-finetuned-sqa/resolve/main/vocab.txt"
+ ),
+ "google/tapas-mini-finetuned-wtq": (
+ "https://huggingface.co/google/tapas-mini-finetuned-wtq/resolve/main/vocab.txt"
+ ),
+ "google/tapas-mini-finetuned-wikisql-supervised": (
+ "https://huggingface.co/google/tapas-mini-finetuned-wikisql-supervised/resolve/main/vocab.txt"
+ ),
+ "google/tapas-mini-finetuned-tabfact": (
+ "https://huggingface.co/google/tapas-mini-finetuned-tabfact/resolve/main/vocab.txt"
+ ),
}
}
@@ -329,8 +377,8 @@ def __init__(
if not os.path.isfile(vocab_file):
raise ValueError(
- f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+ f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+ " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
@@ -594,7 +642,8 @@ def __call__(
if not valid_query:
raise ValueError(
- "queries input must of type `str` (single example), `List[str]` (batch or single pretokenized example). "
+ "queries input must of type `str` (single example), `List[str]` (batch or single pretokenized"
+ " example). "
)
is_batched = isinstance(queries, (list, tuple))
@@ -1229,7 +1278,7 @@ def prepare_for_model(
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logger.warning(
- f"Token indices sequence length is longer than the specified maximum sequence length "
+ "Token indices sequence length is longer than the specified maximum sequence length "
f"for this model ({len(encoded_inputs['input_ids'])} > {self.model_max_length}). Running this "
"sequence through the model will result in indexing errors."
)
diff --git a/src/transformers/models/tapex/__init__.py b/src/transformers/models/tapex/__init__.py
index 36c5938d23c9bd..3b13bed2ca1025 100644
--- a/src/transformers/models/tapex/__init__.py
+++ b/src/transformers/models/tapex/__init__.py
@@ -21,9 +21,7 @@
from ...file_utils import _LazyModule
-_import_structure = {
- "tokenization_tapex": ["TapexTokenizer"],
-}
+_import_structure = {"tokenization_tapex": ["TapexTokenizer"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/tapex/tokenization_tapex.py b/src/transformers/models/tapex/tokenization_tapex.py
index 0b5c1241415ae0..ea1dc0dcc49246 100644
--- a/src/transformers/models/tapex/tokenization_tapex.py
+++ b/src/transformers/models/tapex/tokenization_tapex.py
@@ -503,7 +503,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
diff --git a/src/transformers/models/trajectory_transformer/__init__.py b/src/transformers/models/trajectory_transformer/__init__.py
new file mode 100644
index 00000000000000..0b8a6f2c5892d7
--- /dev/null
+++ b/src/transformers/models/trajectory_transformer/__init__.py
@@ -0,0 +1,68 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+# rely on isort to merge the imports
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_trajectory_transformer": [
+ "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "TrajectoryTransformerConfig",
+ ],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_trajectory_transformer"] = [
+ "TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "TrajectoryTransformerModel",
+ "TrajectoryTransformerPreTrainedModel",
+ "load_tf_weights_in_trajectory_transformer",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ TrajectoryTransformerConfig,
+ )
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TrajectoryTransformerModel,
+ TrajectoryTransformerPreTrainedModel,
+ load_tf_weights_in_trajectory_transformer,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py b/src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py
new file mode 100644
index 00000000000000..537a467c701667
--- /dev/null
+++ b/src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py
@@ -0,0 +1,167 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TrajectoryTransformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "CarlCochet/trajectory-transformer-halfcheetah-medium-v2": (
+ "https://huggingface.co/CarlCochet/trajectory-transformer-halfcheetah-medium-v2/resolve/main/config.json"
+ ),
+ # See all TrajectoryTransformer models at https://huggingface.co/models?filter=trajectory_transformer
+}
+
+
+class TrajectoryTransformerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`TrajectoryTransformerModel`]. It is used to
+ instantiate an TrajectoryTransformer model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ TrajectoryTransformer
+ [CarlCochet/trajectory-transformer-halfcheetah-medium-v2](https://huggingface.co/CarlCochet/trajectory-transformer-halfcheetah-medium-v2)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 100):
+ Vocabulary size of the TrajectoryTransformer model. Defines the number of different tokens that can be
+ represented by the `trajectories` passed when calling [`TrajectoryTransformerModel`]
+ batch_size (`int`, *optional*, defaults to 256):
+ Size of the batch of trajectories passed to the model.
+ action_weight (`int`, *optional*, defaults to 5):
+ Weight of the action in the loss function
+ reward_weight (`int`, *optional*, defaults to 1):
+ Weight of the reward in the loss function
+ value_weight (`int`, *optional*, defaults to 1):
+ Weight of the value in the loss function
+ block_size (`int`, *optional*, defaults to 249):
+ Size of the blocks in the trajectory transformer.
+ action_dim (`int`, *optional*, defaults to 6):
+ Dimension of the action space.
+ observation_dim (`int`, *optional*, defaults to 17):
+ Dimension of the observation space.
+ transition_dim (`int`, *optional*, defaults to 25):
+ Dimension of the transition space.
+ n_layer (`int`, *optional*, defaults to 4):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 4):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ n_embd (`int`, *optional*, defaults to 128):
+ Dimensionality of the embeddings and hidden states.
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`int`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`TrajectoryTransformerModel`]
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ kaiming_initializer_range (`float, *optional*, defaults to 1):
+ A coefficient scaling the negative slope of the kaiming initializer rectifier for EinLinear layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ Example:
+
+ ```python
+ >>> from transformers import TrajectoryTransformerModel, TrajectoryTransformerConfig
+
+ >>> # Initializing a TrajectoryTransformer CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
+ >>> configuration = TrajectoryTransformerConfig()
+
+ >>> # Initializing a model from the CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
+ >>> model = TrajectoryTransformerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "trajectory_transformer"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "hidden_size": "n_embd",
+ "num_attention_heads": "n_head",
+ "num_hidden_layers": "n_layer",
+ }
+
+ def __init__(
+ self,
+ vocab_size=100,
+ batch_size=256,
+ action_weight=5,
+ reward_weight=1,
+ value_weight=1,
+ block_size=249,
+ action_dim=6,
+ observation_dim=17,
+ transition_dim=25,
+ n_layer=4,
+ n_head=4,
+ n_embd=128,
+ embd_pdrop=0.1,
+ attn_pdrop=0.1,
+ resid_pdrop=0.1,
+ learning_rate=0.0006,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ kaiming_initializer_range=1,
+ use_cache=True,
+ is_encoder_decoder=False,
+ pad_token_id=1,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ **kwargs
+ ):
+ self.vocab_size = vocab_size
+ self.batch_size = batch_size
+ self.action_weight = action_weight
+ self.reward_weight = reward_weight
+ self.value_weight = value_weight
+ self.max_position_embeddings = max_position_embeddings
+ self.block_size = block_size
+ self.action_dim = action_dim
+ self.observation_dim = observation_dim
+ self.transition_dim = transition_dim
+ self.learning_rate = learning_rate
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_embd = n_embd
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.resid_pdrop = resid_pdrop
+ self.initializer_range = initializer_range
+ self.type_vocab_size = type_vocab_size
+ self.layer_norm_eps = layer_norm_eps
+ self.kaiming_initializer_range = kaiming_initializer_range
+ self.use_cache = use_cache
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
diff --git a/src/transformers/models/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 00000000000000..14e6556e07b7a1
--- /dev/null
+++ b/src/transformers/models/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,70 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TrajectoryTransformer pytorch checkpoint conversion"""
+
+import torch
+
+import trajectory.utils as utils
+from transformers import TrajectoryTransformerModel
+
+
+class Parser(utils.Parser):
+ dataset: str = "halfcheetah-medium-expert-v2"
+ config: str = "config.offline"
+
+
+def convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch(logbase, dataset, loadpath, epoch, device):
+ """Converting Sequential blocks to ModuleList"""
+
+ gpt, gpt_epoch = utils.load_model(logbase, dataset, loadpath, epoch=epoch, device=device)
+ trajectory_transformer = TrajectoryTransformerModel(gpt.config)
+
+ trajectory_transformer.tok_emb.load_state_dict(gpt.tok_emb.state_dict())
+ trajectory_transformer.pos_emb = gpt.pos_emb
+ trajectory_transformer.drop.load_state_dict(gpt.drop.state_dict())
+ trajectory_transformer.ln_f.load_state_dict(gpt.ln_f.state_dict())
+ trajectory_transformer.head.load_state_dict(gpt.head.state_dict())
+
+ for i, block in enumerate(gpt.blocks):
+ trajectory_transformer.blocks[i].ln1.load_state_dict(gpt.blocks[i].ln1.state_dict())
+ trajectory_transformer.blocks[i].ln2.load_state_dict(gpt.blocks[i].ln2.state_dict())
+ trajectory_transformer.blocks[i].attn.load_state_dict(gpt.blocks[i].attn.state_dict())
+
+ trajectory_transformer.blocks[i].l1.load_state_dict(gpt.blocks[i].mlp[0].state_dict())
+ trajectory_transformer.blocks[i].act.load_state_dict(gpt.blocks[i].mlp[1].state_dict())
+ trajectory_transformer.blocks[i].l2.load_state_dict(gpt.blocks[i].mlp[2].state_dict())
+ trajectory_transformer.blocks[i].drop.load_state_dict(gpt.blocks[i].mlp[3].state_dict())
+
+ torch.save(trajectory_transformer.state_dict(), "pytorch_model.bin")
+
+
+if __name__ == "__main__":
+ """
+ To run this script you will need to install the original repository to run the original model. You can find it
+ here: https://github.com/jannerm/trajectory-transformer From this repository code you can also download the
+ original pytorch checkpoints.
+
+ Run with the command:
+
+ ```sh
+ >>> python convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py --dataset
+ ... --gpt_loadpath
+ ```
+ """
+
+ args = Parser().parse_args("plan")
+ convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch(
+ args.logbase, args.dataset, args.gpt_loadpath, args.gpt_epoch, args.device
+ )
diff --git a/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py
new file mode 100644
index 00000000000000..f647a13afead44
--- /dev/null
+++ b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py
@@ -0,0 +1,617 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch TrajectoryTransformer model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import functional as F
+
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_trajectory_transformer import TrajectoryTransformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
+_CONFIG_FOR_DOC = "TrajectoryTransformerConfig"
+
+TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "CarlCochet/trajectory-transformer-halfcheetah-medium-v2",
+ # See all TrajectoryTransformer models at https://huggingface.co/models?filter=trajectory_transformer
+]
+
+
+def load_tf_weights_in_trajectory_transformer(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+@dataclass
+class TrajectoryTransformerOutput(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
+ sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the
+ attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. GPT2Attentions weights after the attention softmax, used to compute the weighted average
+ in the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class TrajectoryTransformerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = TrajectoryTransformerConfig
+ load_tf_weights = load_tf_weights_in_trajectory_transformer
+ base_model_prefix = "trajectory_transformer"
+ main_input_name = "trajectories"
+ supports_gradient_checkpointing = True
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, TrajectoryTransformerModel):
+ module.gradient_checkpointing = value
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, EinLinear):
+ for i in range(module.n_models):
+ nn.init.kaiming_uniform_(module.weight[i], a=math.sqrt(5) / self.config.kaiming_initializer_range)
+ if module.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight[i])
+ bound = (1 / math.sqrt(fan_in)) * self.config.initializer_range
+ nn.init.uniform_(module.bias[i], -bound, bound)
+
+
+TRAJECTORY_TRANSFORMER_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`TrajectoryTransformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ trajectories (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Batch of trajectories, where a trajectory is a sequence of states, actions and rewards.
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`, *optional*):
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
+ targets (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Desired targets used to compute the loss.
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class EinLinear(nn.Module):
+ def __init__(self, n_models, in_features, out_features, bias):
+ super().__init__()
+ self.n_models = n_models
+ self.out_features = out_features
+ self.in_features = in_features
+ self.weight = nn.Parameter(torch.Tensor(n_models, out_features, in_features))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(n_models, out_features))
+ else:
+ self.register_parameter("bias", None)
+
+ def reset_parameters(self):
+ for i in range(self.n_models):
+ nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias[i], -bound, bound)
+
+ def forward(self, input):
+ """
+ Args:
+ input (`torch.FloatTensor` of shape `(B, n_models, input_dim)`):
+ The input to the layer.
+ """
+ # [ batch_size x n_models x output_dim ]
+ output = torch.einsum("eoi,bei->beo", self.weight, input)
+ if self.bias is not None:
+ raise RuntimeError()
+ return output
+
+
+class CausalSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ if config.n_embd % config.n_head != 0:
+ raise ValueError(f"n_head ({config.n_head}) should be a divisor of n_embd ({config.n_embd})")
+
+ # key, query, value projections for all heads
+ self.key = nn.Linear(config.n_embd, config.n_embd)
+ self.query = nn.Linear(config.n_embd, config.n_embd)
+ self.value = nn.Linear(config.n_embd, config.n_embd)
+
+ # regularization
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
+
+ # output projection
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
+
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ self.register_buffer(
+ "mask",
+ torch.tril(torch.ones(config.block_size, config.block_size)).view(
+ 1, 1, config.block_size, config.block_size
+ ),
+ )
+
+ # mask previous value estimates
+ joined_dim = config.observation_dim + config.action_dim + 2
+ self.mask.squeeze()[:, joined_dim - 1 :: joined_dim] = 0
+
+ self.n_head = config.n_head
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ):
+ batch_size, sequence_length, embedding_dim = hidden_states.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ # [ batch_size x n_heads x sequence_length x head_dim ]
+ key = (
+ self.key(hidden_states)
+ .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
+ .transpose(1, 2)
+ )
+ query = (
+ self.query(hidden_states)
+ .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
+ .transpose(1, 2)
+ )
+ value = (
+ self.value(hidden_states)
+ .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
+ .transpose(1, 2)
+ )
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+
+ if use_cache is True:
+ present = (key, value)
+ else:
+ present = None
+
+ # causal self-attention
+ # [ batch_size x n_heads x sequence_length x sequence_length ]
+ attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1)))
+ attn_weights = attn_weights.masked_fill(
+ self.mask[:, :, :sequence_length, :sequence_length] == 0, float("-inf")
+ )
+ attn_weights = F.softmax(attn_weights, dim=-1)
+ self._attn_map = attn_weights.clone()
+ attn_weights = self.attn_drop(attn_weights)
+
+ output = torch.matmul(attn_weights, value)
+ # [ batch_size x sequence_length x embedding_dim ]
+ # re-assemble all head outputs side by side
+ output = output.transpose(1, 2).contiguous().view(batch_size, sequence_length, embedding_dim)
+
+ # output projection
+ output = self.resid_drop(self.proj(output))
+
+ outputs = (output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class Block(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(config.n_embd)
+ self.ln2 = nn.LayerNorm(config.n_embd)
+ self.attn = CausalSelfAttention(config)
+
+ # MLP
+ self.l1 = nn.Linear(config.n_embd, 4 * config.n_embd)
+ self.act = nn.GELU()
+ self.l2 = nn.Linear(4 * config.n_embd, config.n_embd)
+ self.drop = nn.Dropout(config.resid_pdrop)
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ):
+ residual = hidden_states
+ hidden_states = self.ln1(hidden_states)
+
+ attn_outputs = self.attn(
+ hidden_states, layer_past=layer_past, use_cache=use_cache, output_attentions=output_attentions
+ )
+ attn_output = attn_outputs[0]
+ outputs = attn_outputs[1:]
+ hidden_states = attn_output + residual
+
+ residual = hidden_states
+ hidden_states = self.ln2(hidden_states)
+ hidden_states = self.l1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.l2(hidden_states)
+ hidden_states = residual + self.drop(hidden_states)
+
+ if use_cache:
+ outputs = (hidden_states,) + outputs
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ return outputs
+
+
+@add_start_docstrings(
+ "The bare TrajectoryTransformer Model transformer outputting raw hidden-states without any specific head on top.",
+ TRAJECTORY_TRANSFORMER_START_DOCSTRING,
+)
+class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
+ """the full GPT language model, with a context size of block_size"""
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ # input embedding stem (+1 for stop token)
+ self.tok_emb = nn.Embedding(config.vocab_size * config.transition_dim + 1, config.n_embd)
+
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = EinLinear(config.transition_dim, config.n_embd, config.vocab_size + 1, bias=False)
+
+ self.vocab_size = config.vocab_size
+ self.stop_token = config.vocab_size * config.transition_dim
+ self.block_size = config.block_size
+
+ self.observation_dim = config.observation_dim
+ self.action_dim = config.action_dim
+ self.transition_dim = config.transition_dim
+ self.embedding_dim = config.n_embd
+
+ self.action_weight = config.action_weight
+ self.reward_weight = config.reward_weight
+ self.value_weight = config.value_weight
+
+ self.gradient_checkpointing = False
+
+ self.post_init()
+
+ def get_block_size(self):
+ return self.block_size
+
+ def offset_tokens(self, trajectories):
+ _, sequence_length = trajectories.shape
+
+ n_states = int(np.ceil(sequence_length / self.transition_dim))
+
+ offsets = torch.arange(self.transition_dim) * self.vocab_size
+ offsets = offsets.repeat(n_states).to(trajectories.device)
+
+ offset_trajectories = trajectories + offsets[:sequence_length]
+ offset_trajectories[trajectories == self.vocab_size] = self.stop_token
+ return offset_trajectories
+
+ def pad_to_full_observation(self, hidden_states):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ n_pad = (self.transition_dim - sequence_length % self.transition_dim) % self.transition_dim
+ padding = torch.zeros(batch_size, n_pad, self.embedding_dim, device=hidden_states.device)
+
+ # [ batch_size x padded_sequence_length' x embedding_dim ]
+ hidden_states_pad = torch.cat([hidden_states, padding], dim=1)
+ hidden_states_pad = hidden_states_pad.view(-1, self.transition_dim, self.embedding_dim)
+
+ return hidden_states_pad, n_pad
+
+ @add_start_docstrings_to_model_forward(
+ TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ )
+ @replace_return_docstrings(output_type=TrajectoryTransformerOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ trajectories: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ targets: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import TrajectoryTransformerModel
+ >>> import torch
+
+ >>> model = TrajectoryTransformerModel.from_pretrained(
+ ... "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
+ ... )
+ >>> model.to(device)
+ >>> model.eval()
+
+ >>> observations_dim, action_dim, batch_size = 17, 6, 256
+ >>> seq_length = observations_dim + action_dim + 1
+
+ >>> trajectories = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(batch_size)]).to(
+ ... device
+ ... )
+ >>> targets = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(batch_size)]).to(device)
+
+ >>> outputs = model(
+ ... trajectories,
+ ... targets=targets,
+ ... use_cache=True,
+ ... output_attentions=True,
+ ... output_hidden_states=True,
+ ... return_dict=True,
+ ... )
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.blocks))
+
+ batch_size, sequence_length = trajectories.size()
+
+ if sequence_length > self.block_size:
+ raise ValueError("Cannot forward, model block size is exhausted.")
+
+ offset_trajectories = self.offset_tokens(trajectories)
+ # [ batch_size x sequence_length x embedding_dim ]
+ # forward the GPT model
+ token_embeddings = self.tok_emb(offset_trajectories) # each index maps to a (learnable) vector
+ position_embeddings = self.pos_emb[:, :sequence_length, :] # each position maps to a (learnable) vector
+
+ hidden_states = self.drop(token_embeddings + position_embeddings)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ layer_past,
+ use_cache,
+ output_attentions,
+ )
+ else:
+ outputs = block(hidden_states, layer_past, use_cache, output_attentions)
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # [ batch_size x sequence_length x embedding_dim ]
+ hidden_state = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states_pad, n_pad = self.pad_to_full_observation(hidden_state)
+
+ logits = self.head(hidden_states_pad)
+ logits = logits.reshape(batch_size, sequence_length + n_pad, self.vocab_size + 1)
+ logits = logits[:, :sequence_length]
+
+ # if we are given some desired targets also calculate the loss
+ if targets is not None:
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction="none")
+ if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1:
+ # make weights
+ n_states = int(np.ceil(sequence_length / self.transition_dim))
+ weights = torch.cat(
+ [
+ torch.ones(self.observation_dim, device=trajectories.device),
+ torch.ones(self.action_dim, device=trajectories.device) * self.action_weight,
+ torch.ones(1, device=trajectories.device) * self.reward_weight,
+ torch.ones(1, device=trajectories.device) * self.value_weight,
+ ]
+ )
+ weights = weights.repeat(n_states)
+ weights = weights[1:].repeat(batch_size, 1)
+ loss = loss * weights.view(-1)
+ loss = (loss * attention_mask.view(-1)).mean()
+ else:
+ loss = None
+
+ if not return_dict:
+ return tuple(v for v in [loss, logits, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return TrajectoryTransformerOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
diff --git a/src/transformers/models/transfo_xl/__init__.py b/src/transformers/models/transfo_xl/__init__.py
index ed01124a490590..672ad9afc5274f 100644
--- a/src/transformers/models/transfo_xl/__init__.py
+++ b/src/transformers/models/transfo_xl/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
@@ -26,7 +26,12 @@
"tokenization_transfo_xl": ["TransfoXLCorpus", "TransfoXLTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_transfo_xl"] = [
"TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"AdaptiveEmbedding",
@@ -37,7 +42,12 @@
"load_tf_weights_in_transfo_xl",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_transfo_xl"] = [
"TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFAdaptiveEmbedding",
@@ -53,7 +63,12 @@
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding,
@@ -64,7 +79,12 @@
load_tf_weights_in_transfo_xl,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
diff --git a/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
index abde04bd43c721..646c8a2342fc3a 100755
--- a/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
@@ -101,8 +101,10 @@ def convert_transfo_xl_checkpoint_to_pytorch(
"--transfo_xl_config_file",
default="",
type=str,
- help="An optional config json file corresponding to the pre-trained BERT model. \n"
- "This specifies the model architecture.",
+ help=(
+ "An optional config json file corresponding to the pre-trained BERT model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--transfo_xl_dataset_file",
diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
index 29753738839c8a..66467350f14218 100644
--- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
@@ -935,9 +935,10 @@ def __init__(self, config):
super().__init__(config)
self.transformer = TFTransfoXLMainLayer(config, name="transformer")
self.sample_softmax = config.sample_softmax
- assert (
- self.sample_softmax <= 0
- ), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
+ assert self.sample_softmax <= 0, (
+ "Sampling from the softmax is not implemented yet. Please look at issue: #3310:"
+ " https://github.com/huggingface/transformers/issues/3310"
+ )
self.crit = TFAdaptiveSoftmaxMask(
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
@@ -1126,7 +1127,7 @@ def call(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
index af95f348ec28f7..dcfa84d0f94b69 100644
--- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
+++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
@@ -111,7 +111,7 @@ def _logit(x, W, b, proj=None):
@staticmethod
def _gather_logprob(logprob, target):
lp_size = shape_list(logprob)
- r = tf.range(lp_size[0])
+ r = tf.range(lp_size[0], dtype=target.dtype)
idx = tf.stack([r, target], 1)
return tf.gather_nd(logprob, idx)
diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py
index f566262ff29404..7986aa7af9b61b 100644
--- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py
@@ -881,14 +881,14 @@ def _update_mems(self, hids, mems, mlen, qlen):
)
def forward(
self,
- input_ids=None,
- mems=None,
- head_mask=None,
- inputs_embeds=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ mems: Optional[List[torch.FloatTensor]] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TransfoXLModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1020,13 +1020,15 @@ def __init__(self, config):
if not self.trainer_compatible:
warnings.warn(
"The output of TransfoXL will be updated in v5 to support a single loss as first argument. In order"
- "to use that updated output, please specify `trainer_compatible=True` as your configuration attribute.",
+ "to use that updated output, please specify `trainer_compatible=True` as your configuration"
+ " attribute.",
DeprecationWarning,
)
- assert (
- self.sample_softmax <= 0
- ), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
+ assert self.sample_softmax <= 0, (
+ "Sampling from the softmax is not implemented yet. Please look at issue: #3310:"
+ " https://github.com/huggingface/transformers/issues/3310"
+ )
self.crit = ProjectedAdaptiveLogSoftmax(
config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
@@ -1071,15 +1073,15 @@ def init_mems(self, bsz):
)
def forward(
self,
- input_ids=None,
- mems=None,
- head_mask=None,
- inputs_embeds=None,
- labels=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
+ input_ids: Optional[torch.LongTensor] = None,
+ mems: Optional[List[torch.FloatTensor]] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TransfoXLLMHeadModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -1196,7 +1198,7 @@ def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[tor
TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1215,11 +1217,11 @@ def __init__(self, config):
)
def forward(
self,
- input_ids: Optional[torch.Tensor] = None,
+ input_ids: Optional[torch.LongTensor] = None,
mems: Optional[List[torch.FloatTensor]] = None,
- head_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1261,7 +1263,7 @@ def forward(
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl_utilities.py b/src/transformers/models/transfo_xl/modeling_transfo_xl_utilities.py
index b25dc2d707d6d9..e25ba2cd476a0b 100644
--- a/src/transformers/models/transfo_xl/modeling_transfo_xl_utilities.py
+++ b/src/transformers/models/transfo_xl/modeling_transfo_xl_utilities.py
@@ -102,7 +102,7 @@ def forward(self, hidden, labels=None, keep_order=False):
hidden = hidden.view(-1, hidden.size(-1))
labels = labels.view(-1)
if hidden.size(0) != labels.size(0):
- raise RuntimeError("Input and labels should have the same size " "in the batch dimension.")
+ raise RuntimeError("Input and labels should have the same size in the batch dimension.")
else:
hidden = hidden.view(-1, hidden.size(-1))
diff --git a/src/transformers/models/transfo_xl/tokenization_transfo_xl.py b/src/transformers/models/transfo_xl/tokenization_transfo_xl.py
index 115cd4fdcfca90..cc72925bb03358 100644
--- a/src/transformers/models/transfo_xl/tokenization_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/tokenization_transfo_xl.py
@@ -680,10 +680,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs,
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
- f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list "
- f"({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. "
- f"We assumed '{pretrained_model_name_or_path}' was a path or url but couldn't find files {corpus_file} "
- "at this path or url."
+ f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
+ f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
+ f" was a path or url but couldn't find files {corpus_file} at this path or url."
)
return None
if resolved_corpus_file == corpus_file:
diff --git a/src/transformers/models/trocr/__init__.py b/src/transformers/models/trocr/__init__.py
index 5f9f462e183992..8e18eaeb4069e9 100644
--- a/src/transformers/models/trocr/__init__.py
+++ b/src/transformers/models/trocr/__init__.py
@@ -17,19 +17,27 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_speech_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_speech_available,
+ is_torch_available,
+)
_import_structure = {
- "configuration_trocr": [
- "TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "TrOCRConfig",
- ],
+ "configuration_trocr": ["TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP", "TrOCRConfig"],
"processing_trocr": ["TrOCRProcessor"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_trocr"] = [
"TROCR_PRETRAINED_MODEL_ARCHIVE_LIST",
"TrOCRForCausalLM",
@@ -41,7 +49,12 @@
from .configuration_trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig
from .processing_trocr import TrOCRProcessor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_trocr import TROCR_PRETRAINED_MODEL_ARCHIVE_LIST, TrOCRForCausalLM, TrOCRPreTrainedModel
else:
diff --git a/src/transformers/models/trocr/configuration_trocr.py b/src/transformers/models/trocr/configuration_trocr.py
index fc878da26d5121..a635e6b9b09729 100644
--- a/src/transformers/models/trocr/configuration_trocr.py
+++ b/src/transformers/models/trocr/configuration_trocr.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/trocr-base-handwritten": "https://huggingface.co/microsoft/trocr-base-handwritten/resolve/main/config.json",
+ "microsoft/trocr-base-handwritten": (
+ "https://huggingface.co/microsoft/trocr-base-handwritten/resolve/main/config.json"
+ ),
# See all TrOCR models at https://huggingface.co/models?filter=trocr
}
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index 75e015f98848dc..3a960bb86f3cbc 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -72,7 +72,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
@@ -182,7 +182,8 @@ def __init__(
self.head_dim = embed_dim // num_heads
if not (self.head_dim * num_heads == self.embed_dim):
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
@@ -254,7 +255,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -270,7 +272,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -291,7 +294,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -520,7 +524,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -667,7 +671,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -683,7 +688,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
+ " False`..."
)
use_cache = False
@@ -769,7 +775,8 @@ def forward(self, *args, **kwargs):
@add_start_docstrings(
- "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and [`VisionEncoderDecoder`].",
+ "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and"
+ " [`VisionEncoderDecoder`].",
TROCR_START_DOCSTRING,
)
class TrOCRForCausalLM(TrOCRPreTrainedModel):
diff --git a/src/transformers/models/unispeech/__init__.py b/src/transformers/models/unispeech/__init__.py
index 537b125ec0ef86..3713e7d8a11ceb 100644
--- a/src/transformers/models/unispeech/__init__.py
+++ b/src/transformers/models/unispeech/__init__.py
@@ -17,14 +17,23 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_unispeech": ["UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechConfig"],
-}
+_import_structure = {"configuration_unispeech": ["UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_unispeech"] = [
"UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST",
"UniSpeechForCTC",
@@ -37,7 +46,12 @@
if TYPE_CHECKING:
from .configuration_unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_unispeech import (
UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST,
UniSpeechForCTC,
diff --git a/src/transformers/models/unispeech/configuration_unispeech.py b/src/transformers/models/unispeech/configuration_unispeech.py
index 85b99859209459..733e68e627bdd9 100644
--- a/src/transformers/models/unispeech/configuration_unispeech.py
+++ b/src/transformers/models/unispeech/configuration_unispeech.py
@@ -24,7 +24,9 @@
logger = logging.get_logger(__name__)
UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/unispeech-large-1500h-cv": "https://huggingface.co/microsoft/unispeech-large-1500h-cv/resolve/main/config.json",
+ "microsoft/unispeech-large-1500h-cv": (
+ "https://huggingface.co/microsoft/unispeech-large-1500h-cv/resolve/main/config.json"
+ ),
# See all UniSpeech models at https://huggingface.co/models?filter=unispeech
}
@@ -78,13 +80,13 @@ class UniSpeechConfig(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
*conv_dim*.
@@ -261,10 +263,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
index 83f051627cc39c..bf729309515eac 100644
--- a/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
@@ -84,9 +84,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type, is_finetuned
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -154,28 +155,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 61359bf032f0c2..8c17708fb72618 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import (
@@ -71,35 +71,6 @@
]
-@dataclass
-class UniSpeechBaseModelOutput(ModelOutput):
- """
- Output type of [`UniSpeechBaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
@dataclass
class UniSpeechForPreTrainingOutput(ModelOutput):
"""
@@ -554,7 +525,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -570,7 +542,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -591,7 +564,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -723,7 +697,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -811,7 +786,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- hidden_states[~attention_mask] = 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -888,7 +864,8 @@ def __init__(self, config):
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
- f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`"
+ f" {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
@@ -1154,7 +1131,7 @@ def _mask_hidden_states(
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=UniSpeechBaseModelOutput,
+ output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1167,7 +1144,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, UniSpeechBaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1199,7 +1176,7 @@ def forward(
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return UniSpeechBaseModelOutput(
+ return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
diff --git a/src/transformers/models/unispeech_sat/__init__.py b/src/transformers/models/unispeech_sat/__init__.py
index 75a7397ff7e46e..d4a5e179539a39 100644
--- a/src/transformers/models/unispeech_sat/__init__.py
+++ b/src/transformers/models/unispeech_sat/__init__.py
@@ -17,14 +17,25 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
_import_structure = {
"configuration_unispeech_sat": ["UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP", "UniSpeechSatConfig"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_unispeech_sat"] = [
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
"UniSpeechSatForAudioFrameClassification",
@@ -39,7 +50,12 @@
if TYPE_CHECKING:
from .configuration_unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_unispeech_sat import (
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
UniSpeechSatForAudioFrameClassification,
diff --git a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
index b88d9cf91fc9fd..bc8663587d96d8 100644
--- a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
@@ -24,7 +24,9 @@
logger = logging.get_logger(__name__)
UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/unispeech-sat-base-100h-libri-ft": "https://huggingface.co/microsoft/unispeech-sat-base-100h-libri-ft/resolve/main/config.json",
+ "microsoft/unispeech-sat-base-100h-libri-ft": (
+ "https://huggingface.co/microsoft/unispeech-sat-base-100h-libri-ft/resolve/main/config.json"
+ ),
# See all UniSpeechSat models at https://huggingface.co/models?filter=unispeech_sat
}
@@ -79,13 +81,13 @@ class UniSpeechSatConfig(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
*conv_dim*.
@@ -157,13 +159,13 @@ class UniSpeechSatConfig(PretrainedConfig):
instance of [`UniSpeechSatForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
- tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
- tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
- tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
@@ -273,10 +275,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
index 78a541d7ed49a5..93750b64cc3a2d 100644
--- a/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
@@ -72,7 +72,8 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
if hf_shape != value.shape:
raise ValueError(
- f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
)
if weight_type == "weight":
@@ -146,14 +147,16 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
@@ -161,14 +164,16 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 1812cd65237ee9..5c80c693e8aebe 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -27,7 +27,14 @@
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
from ...utils import (
@@ -77,35 +84,6 @@
]
-@dataclass
-class UniSpeechSatBaseModelOutput(ModelOutput):
- """
- Output type of [`UniSpeechSatBaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
@dataclass
class UniSpeechSatForPreTrainingOutput(ModelOutput):
"""
@@ -143,38 +121,6 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
-@dataclass
-class XVectorOutput(ModelOutput):
- """
- Output type of [`Wav2Vec2ForXVector`].
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Classification loss.
- logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Classification hidden states before AMSoftmax.
- embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Utterance embeddings used for vector similarity-based retrieval.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- embeddings: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
@@ -593,7 +539,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -609,7 +556,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -630,7 +578,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -762,7 +711,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -850,7 +800,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- hidden_states[~attention_mask] = 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -927,7 +878,8 @@ def __init__(self, config):
if config.codevector_dim % self.num_groups != 0:
raise ValueError(
- f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ f"`config.codevector_dim {config.codevector_dim} must be divisible by `config.num_codevector_groups`"
+ f" {self.num_groups} for concatenation"
)
# storage for codebook variables (codewords)
@@ -1194,7 +1146,7 @@ def _mask_hidden_states(
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=UniSpeechSatBaseModelOutput,
+ output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1207,7 +1159,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, UniSpeechSatBaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1239,7 +1191,7 @@ def forward(
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return UniSpeechSatBaseModelOutput(
+ return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1651,13 +1603,15 @@ def __init__(self, config):
if hasattr(config, "add_adapter") and config.add_adapter:
raise ValueError(
- "Audio frame classification does not support the use of UniSpeechSat adapters (config.add_adapter=True)"
+ "Audio frame classification does not support the use of UniSpeechSat adapters"
+ " (config.add_adapter=True)"
)
self.unispeech_sat = UniSpeechSatModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
self.init_weights()
@@ -1701,6 +1655,7 @@ def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1733,12 +1688,17 @@ def forward(
logits = self.classifier(hidden_states)
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
- loss=None,
+ loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
diff --git a/src/transformers/models/van/__init__.py b/src/transformers/models/van/__init__.py
index 73e2752b1f2e2f..44c88f0448c30b 100644
--- a/src/transformers/models/van/__init__.py
+++ b/src/transformers/models/van/__init__.py
@@ -18,15 +18,18 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_van": ["VAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "VanConfig"],
-}
+_import_structure = {"configuration_van": ["VAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "VanConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_van"] = [
"VAN_PRETRAINED_MODEL_ARCHIVE_LIST",
"VanForImageClassification",
@@ -37,7 +40,12 @@
if TYPE_CHECKING:
from .configuration_van import VAN_PRETRAINED_CONFIG_ARCHIVE_MAP, VanConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_van import (
VAN_PRETRAINED_MODEL_ARCHIVE_LIST,
VanForImageClassification,
diff --git a/src/transformers/models/van/configuration_van.py b/src/transformers/models/van/configuration_van.py
index 6d4becdf552bb2..47d5a9b6c11aa1 100644
--- a/src/transformers/models/van/configuration_van.py
+++ b/src/transformers/models/van/configuration_van.py
@@ -21,7 +21,9 @@
logger = logging.get_logger(__name__)
VAN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "Visual-Attention-Network/van-base": "https://huggingface.co/Visual-Attention-Network/van-base/blob/main/config.json",
+ "Visual-Attention-Network/van-base": (
+ "https://huggingface.co/Visual-Attention-Network/van-base/blob/main/config.json"
+ ),
}
diff --git a/src/transformers/models/van/convert_van_to_pytorch.py b/src/transformers/models/van/convert_van_to_pytorch.py
index cb79c82c5c9e6e..e2c0c95e64502b 100644
--- a/src/transformers/models/van/convert_van_to_pytorch.py
+++ b/src/transformers/models/van/convert_van_to_pytorch.py
@@ -85,7 +85,8 @@ def __call__(self, x: Tensor):
if len(dest_traced) != len(src_traced):
raise Exception(
- f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}."
+ f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
+ f" destination module has {len(dest_traced)}."
)
for dest_m, src_m in zip(dest_traced, src_traced):
@@ -208,10 +209,18 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
}
names_to_original_checkpoints = {
- "van-tiny": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar",
- "van-small": "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar",
- "van-base": "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar",
- "van-large": "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar",
+ "van-tiny": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar"
+ ),
+ "van-small": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar"
+ ),
+ "van-base": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar"
+ ),
+ "van-large": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar"
+ ),
}
if model_name:
@@ -242,7 +251,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
"--model-name",
default=None,
type=str,
- help="The name of the model you wish to convert, it must be one of the supported resnet* architecture, currently: van-tiny/small/base/large. If `None`, all of them will the converted.",
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported resnet* architecture,"
+ " currently: van-tiny/small/base/large. If `None`, all of them will the converted."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
@@ -255,7 +267,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
"--van_dir",
required=True,
type=Path,
- help="A path to VAN's original implementation directory. You can download from here: https://github.com/Visual-Attention-Network/VAN-Classification",
+ help=(
+ "A path to VAN's original implementation directory. You can download from here:"
+ " https://github.com/Visual-Attention-Network/VAN-Classification"
+ ),
)
parser.add_argument(
"--push_to_hub",
diff --git a/src/transformers/models/van/modeling_van.py b/src/transformers/models/van/modeling_van.py
index 3ea59cbae8d359..b94f9969ba9d28 100644
--- a/src/transformers/models/van/modeling_van.py
+++ b/src/transformers/models/van/modeling_van.py
@@ -16,6 +16,7 @@
import math
from collections import OrderedDict
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -81,11 +82,11 @@ def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
-class VanOverlappingPatchEmbedder(nn.Sequential):
+class VanOverlappingPatchEmbedder(nn.Module):
"""
Downsamples the input using a patchify operation with a `stride` of 4 by default making adjacent windows overlap by
half of the area. From [PVTv2: Improved Baselines with Pyramid Vision
@@ -99,8 +100,13 @@ def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stri
)
self.normalization = nn.BatchNorm2d(hidden_size)
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.convolution(input)
+ hidden_state = self.normalization(hidden_state)
+ return hidden_state
+
-class VanMlpLayer(nn.Sequential):
+class VanMlpLayer(nn.Module):
"""
MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision
Transformer](https://arxiv.org/abs/2106.13797).
@@ -122,8 +128,17 @@ def __init__(
self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
self.dropout2 = nn.Dropout(dropout_rate)
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.in_dense(hidden_state)
+ hidden_state = self.depth_wise(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.dropout1(hidden_state)
+ hidden_state = self.out_dense(hidden_state)
+ hidden_state = self.dropout2(hidden_state)
+ return hidden_state
+
-class VanLargeKernelAttention(nn.Sequential):
+class VanLargeKernelAttention(nn.Module):
"""
Basic Large Kernel Attention (LKA).
"""
@@ -136,6 +151,12 @@ def __init__(self, hidden_size: int):
)
self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.depth_wise(hidden_state)
+ hidden_state = self.depth_wise_dilated(hidden_state)
+ hidden_state = self.point_wise(hidden_state)
+ return hidden_state
+
class VanLargeKernelAttentionLayer(nn.Module):
"""
@@ -146,7 +167,7 @@ def __init__(self, hidden_size: int):
super().__init__()
self.attention = VanLargeKernelAttention(hidden_size)
- def forward(self, hidden_state):
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
attention = self.attention(hidden_state)
attended = hidden_state * attention
return attended
@@ -171,7 +192,7 @@ def __init__(self, hidden_size: int, hidden_act: str = "gelu"):
self.attention_layer = VanLargeKernelAttentionLayer(hidden_size)
self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
- def forward(self, hidden_state):
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
residual = hidden_state
hidden_state = self.pre_projection(hidden_state)
hidden_state = self.attention_layer(hidden_state)
@@ -189,7 +210,7 @@ def __init__(self, hidden_size: int, initial_value: float = 1e-2):
super().__init__()
self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True)
- def forward(self, hidden_state):
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
# unsqueezing for broadcasting
hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state
return hidden_state
@@ -218,7 +239,7 @@ def __init__(
)
self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
- def forward(self, hidden_state):
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
residual = hidden_state
# attention
hidden_state = self.pre_normomalization(hidden_state)
@@ -269,7 +290,7 @@ def __init__(
)
self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
- def forward(self, hidden_state):
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.embeddings(hidden_state)
hidden_state = self.layers(hidden_state)
# rearrange b c h w -> b (h w) c
@@ -316,7 +337,12 @@ def __init__(self, config: VanConfig):
)
)
- def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
all_hidden_states = () if output_hidden_states else None
for _, stage_module in enumerate(self.stages):
@@ -389,7 +415,8 @@ def _set_gradient_checkpointing(self, module, value=False):
@add_start_docstrings(
- "The bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding layer.",
+ "The bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding"
+ " layer.",
VAN_START_DOCSTRING,
)
class VanModel(VanPreTrainedModel):
@@ -411,7 +438,12 @@ def __init__(self, config):
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
- def forward(self, pixel_values, output_hidden_states=None, return_dict=None):
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor],
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
@@ -463,7 +495,13 @@ def __init__(self, config):
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
- def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
diff --git a/src/transformers/models/vilt/__init__.py b/src/transformers/models/vilt/__init__.py
index 7aa27b98deca1d..3861b081be2fb3 100644
--- a/src/transformers/models/vilt/__init__.py
+++ b/src/transformers/models/vilt/__init__.py
@@ -18,18 +18,26 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_torch_available, is_vision_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
-_import_structure = {
- "configuration_vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig"],
-}
+_import_structure = {"configuration_vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig"]}
-if is_vision_available():
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_vilt"] = ["ViltFeatureExtractor"]
_import_structure["processing_vilt"] = ["ViltProcessor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vilt"] = [
"VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViltForImageAndTextRetrieval",
@@ -45,11 +53,21 @@
if TYPE_CHECKING:
from .configuration_vilt import VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViltConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_vilt import ViltFeatureExtractor
from .processing_vilt import ViltProcessor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vilt import (
VILT_PRETRAINED_MODEL_ARCHIVE_LIST,
ViltForImageAndTextRetrieval,
diff --git a/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py b/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
index 9de026ebec86fc..3a186e1d2d918a 100644
--- a/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
+++ b/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
@@ -231,7 +231,10 @@ def convert_vilt_checkpoint(checkpoint_url, pytorch_dump_folder_path):
if nlvr_model:
image1 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
- text = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
+ text = (
+ "The left image contains twice the number of dogs as the right image, and at least two dogs in total are"
+ " standing."
+ )
encoding_1 = processor(image1, text, return_tensors="pt")
encoding_2 = processor(image2, text, return_tensors="pt")
outputs = model(
diff --git a/src/transformers/models/vilt/feature_extraction_vilt.py b/src/transformers/models/vilt/feature_extraction_vilt.py
index 7fdd138750ac82..0c64c10959bd8d 100644
--- a/src/transformers/models/vilt/feature_extraction_vilt.py
+++ b/src/transformers/models/vilt/feature_extraction_vilt.py
@@ -33,6 +33,7 @@
if is_torch_available():
import torch
+
logger = logging.get_logger(__name__)
diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py
index f29057addecc03..174799318a80d8 100755
--- a/src/transformers/models/vilt/modeling_vilt.py
+++ b/src/transformers/models/vilt/modeling_vilt.py
@@ -41,6 +41,12 @@
logger = logging.get_logger(__name__)
+if torch.__version__ < (1, 10, 0):
+ logger.warning(
+ f"You are using torch=={torch.__version__}, but torch>=1.10.0 is required to use "
+ "ViltModel. Please upgrade torch."
+ )
+
_CONFIG_FOR_DOC = "ViltConfig"
_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm"
@@ -843,7 +849,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
encoder_outputs = self.encoder(
embedding_output,
diff --git a/src/transformers/models/vision_encoder_decoder/__init__.py b/src/transformers/models/vision_encoder_decoder/__init__.py
index 0757f15ec8197a..5d501b8feb83c4 100644
--- a/src/transformers/models/vision_encoder_decoder/__init__.py
+++ b/src/transformers/models/vision_encoder_decoder/__init__.py
@@ -18,32 +18,66 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
-}
+_import_structure = {"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel
else:
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py
index e0478f1e13a5f6..7042b2548deb76 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py
@@ -301,10 +301,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -832,10 +832,9 @@ def from_encoder_decoder_pretrained(
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
index edc2973a073489..ba65525ae00b12 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
@@ -43,10 +43,10 @@
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
DEPRECATION_WARNING = (
- "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
- "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
- "a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
- "need to pass them yourself anymore."
+ "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+ " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+ " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the"
+ " labels, no need to pass them yourself anymore."
)
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
@@ -202,10 +202,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# initialize with config
@@ -222,11 +222,13 @@ def __init__(
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
@@ -326,7 +328,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
>>> output_ids = model.generate(
... pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
- >>> ).sequences
+ ... ).sequences
>>> preds = decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
>>> preds = [pred.strip() for pred in preds]
@@ -337,10 +339,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
from_pt = kwargs.pop("from_pt", False)
if from_pt:
raise ValueError(
- "Initializing `TFVisionEncoderDecoderModel` from a pytorch checkpoint is not supported currently. "
- "Use a tensorflow checkpoint instead. If only the pytorch checkpoints are available, "
- "create the encoder and decoder models separately, and use them to initialize `TFVisionEncoderDecoderModel`. "
- "Check `TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
+ "Initializing `TFVisionEncoderDecoderModel` from a pytorch checkpoint is not supported currently. Use"
+ " a tensorflow checkpoint instead. If only the pytorch checkpoints are available, create the encoder"
+ " and decoder models separately, and use them to initialize `TFVisionEncoderDecoderModel`. Check"
+ " `TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
@@ -469,10 +471,9 @@ def from_encoder_decoder_pretrained(
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
index 37072270a567d8..d2c4ae6b18cf32 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -173,10 +173,10 @@ def __init__(
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
- "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
- "it has to be equal to the encoder's `hidden_size`. "
- f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
- f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
)
# initialize with config
@@ -195,11 +195,13 @@ def __init__(
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
@@ -369,10 +371,9 @@ def from_encoder_decoder_pretrained(
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
- f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
- f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
- "cross attention layers."
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
@@ -546,8 +547,8 @@ def prepare_inputs_for_generation(
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
- "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported."
- "Please use the respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
+ "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
+ " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
diff --git a/src/transformers/models/vision_text_dual_encoder/__init__.py b/src/transformers/models/vision_text_dual_encoder/__init__.py
index 4e705cd03721ee..89aa78c831129f 100644
--- a/src/transformers/models/vision_text_dual_encoder/__init__.py
+++ b/src/transformers/models/vision_text_dual_encoder/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_flax_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available
_import_structure = {
@@ -27,11 +27,21 @@
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]
@@ -39,10 +49,20 @@
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
from .processing_visiotn_text_dual_encoder import VisionTextDualEncoderProcessor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py
index 4cf6c59882aa70..aac1b0e8e93d2f 100644
--- a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py
+++ b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py
@@ -536,9 +536,9 @@ def from_vision_text_pretrained(
# the projection layers are always newly initialized when loading the model
# using pre-trained vision and text model.
logger.warning(
- "The projection layer and logit scale weights `[('visual_projection', 'kernel'), ('text_projection', 'kernel'), ('logit_scale',)]` "
- "are newly initialized. You should probably TRAIN this model on a down-stream task "
- "to be able to use it for predictions and inference."
+ "The projection layer and logit scale weights `[('visual_projection', 'kernel'), ('text_projection',"
+ " 'kernel'), ('logit_scale',)]` are newly initialized. You should probably TRAIN this model on a"
+ " down-stream task to be able to use it for predictions and inference."
)
return model
diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
index e13c9ca7ef8f74..66340deaf4927f 100755
--- a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
+++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
@@ -530,9 +530,9 @@ def from_vision_text_pretrained(
# the projection layers are always newly initialized when loading the model
# using pre-trained vision and text model.
logger.warning(
- "The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight', 'logit_scale']` "
- "are newly initialized. You should probably TRAIN this model on a down-stream task "
- "to be able to use it for predictions and inference."
+ "The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight',"
+ " 'logit_scale']` are newly initialized. You should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
)
return model
diff --git a/src/transformers/models/visual_bert/__init__.py b/src/transformers/models/visual_bert/__init__.py
index 444929e1517911..f7a5390d1348f0 100644
--- a/src/transformers/models/visual_bert/__init__.py
+++ b/src/transformers/models/visual_bert/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
-}
+_import_structure = {"configuration_visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_visual_bert"] = [
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"VisualBertForMultipleChoice",
@@ -41,7 +44,12 @@
if TYPE_CHECKING:
from .configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_visual_bert import (
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
VisualBertForMultipleChoice,
diff --git a/src/transformers/models/visual_bert/configuration_visual_bert.py b/src/transformers/models/visual_bert/configuration_visual_bert.py
index d4992d5267f8cb..60a3692644d7b3 100644
--- a/src/transformers/models/visual_bert/configuration_visual_bert.py
+++ b/src/transformers/models/visual_bert/configuration_visual_bert.py
@@ -23,13 +23,19 @@
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"uclanlp/visualbert-vqa": "https://huggingface.co/uclanlp/visualbert-vqa/resolve/main/config.json",
"uclanlp/visualbert-vqa-pre": "https://huggingface.co/uclanlp/visualbert-vqa-pre/resolve/main/config.json",
- "uclanlp/visualbert-vqa-coco-pre": "https://huggingface.co/uclanlp/visualbert-vqa-coco-pre/resolve/main/config.json",
+ "uclanlp/visualbert-vqa-coco-pre": (
+ "https://huggingface.co/uclanlp/visualbert-vqa-coco-pre/resolve/main/config.json"
+ ),
"uclanlp/visualbert-vcr": "https://huggingface.co/uclanlp/visualbert-vcr/resolve/main/config.json",
"uclanlp/visualbert-vcr-pre": "https://huggingface.co/uclanlp/visualbert-vcr-pre/resolve/main/config.json",
- "uclanlp/visualbert-vcr-coco-pre": "https://huggingface.co/uclanlp/visualbert-vcr-coco-pre/resolve/main/config.json",
+ "uclanlp/visualbert-vcr-coco-pre": (
+ "https://huggingface.co/uclanlp/visualbert-vcr-coco-pre/resolve/main/config.json"
+ ),
"uclanlp/visualbert-nlvr2": "https://huggingface.co/uclanlp/visualbert-nlvr2/resolve/main/config.json",
"uclanlp/visualbert-nlvr2-pre": "https://huggingface.co/uclanlp/visualbert-nlvr2-pre/resolve/main/config.json",
- "uclanlp/visualbert-nlvr2-coco-pre": "https://huggingface.co/uclanlp/visualbert-nlvr2-coco-pre/resolve/main/config.json"
+ "uclanlp/visualbert-nlvr2-coco-pre": (
+ "https://huggingface.co/uclanlp/visualbert-nlvr2-coco-pre/resolve/main/config.json"
+ )
# See all VisualBERT models at https://huggingface.co/models?filter=visual_bert
}
diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py
index 506b0c749aee09..118ab1fe5c3ce9 100755
--- a/src/transformers/models/visual_bert/modeling_visual_bert.py
+++ b/src/transformers/models/visual_bert/modeling_visual_bert.py
@@ -158,7 +158,8 @@ def forward(
if (image_text_alignment_mask == 0).sum() != 0:
image_text_alignment_mask[image_text_alignment_mask == 0] = 1 # Avoid divide by zero error
logger.warning(
- "Found 0 values in `image_text_alignment_mask`. Setting them to 1 to avoid divide-by-zero error."
+ "Found 0 values in `image_text_alignment_mask`. Setting them to 1 to avoid divide-by-zero"
+ " error."
)
visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)
@@ -794,12 +795,12 @@ def forward(
if visual_embeds is not None:
combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
+ combined_attention_mask, (batch_size, input_shape + visual_input_shape)
)
else:
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
- attention_mask, [batch_size, input_shape], device
+ attention_mask, (batch_size, input_shape)
)
# Prepare head mask if needed
@@ -928,7 +929,7 @@ def forward(
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisualBertForPreTraining.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
- inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
+ inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
visual_embeds = get_visual_embeddings(image).unsqueeze(0)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
@@ -978,7 +979,7 @@ def forward(
total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
if labels.size(-1) != total_size:
raise ValueError(
- f"The labels provided should have same sequence length as total attention mask. "
+ "The labels provided should have same sequence length as total attention mask. "
f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
)
@@ -991,7 +992,7 @@ def forward(
total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
if labels.size(-1) != total_size:
raise ValueError(
- f"The labels provided should have same sequence length as total attention mask. "
+ "The labels provided should have same sequence length as total attention mask. "
f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
)
diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py
index c0331e27d9d50d..b30a9ec15d9d50 100644
--- a/src/transformers/models/vit/__init__.py
+++ b/src/transformers/models/vit/__init__.py
@@ -17,17 +17,32 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available
-
-
-_import_structure = {
- "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig", "ViTOnnxConfig"],
-}
-
-if is_vision_available():
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+ is_vision_available,
+)
+
+
+_import_structure = {"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig", "ViTOnnxConfig"]}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vit"] = [
"VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTForImageClassification",
@@ -36,14 +51,24 @@
"ViTPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_vit"] = [
"TFViTForImageClassification",
"TFViTModel",
"TFViTPreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
@@ -53,10 +78,20 @@
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig, ViTOnnxConfig
- if is_vision_available():
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .feature_extraction_vit import ViTFeatureExtractor
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vit import (
VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTForImageClassification,
@@ -65,10 +100,20 @@
ViTPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py
index eaa7c4225e8c93..f6e7044057361b 100644
--- a/src/transformers/models/vit/modeling_flax_vit.py
+++ b/src/transformers/models/vit/modeling_flax_vit.py
@@ -143,7 +143,8 @@ class FlaxViTSelfAttention(nn.Module):
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
- "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
+ " {self.config.num_attention_heads}"
)
self.query = nn.Dense(
diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py
index 9d478e968cfc0a..46662596612595 100644
--- a/src/transformers/models/vit/modeling_tf_vit.py
+++ b/src/transformers/models/vit/modeling_tf_vit.py
@@ -187,7 +187,8 @@ def call(
if getattr(height, "numpy", None) and getattr(width, "numpy", None):
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
- f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
)
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index b2fc044fcb09c4..dde36b45ef5bb5 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -186,7 +186,8 @@ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = F
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
- f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
return x
@@ -213,7 +214,7 @@ def __init__(self, config: ViTConfig) -> None:
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -245,7 +246,7 @@ def forward(
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -612,7 +613,8 @@ def forward(self, hidden_states):
@add_start_docstrings(
- "ViT Model with a decoder on top for masked image modeling, as proposed in `SimMIM `__.",
+ "ViT Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
+ " `__.",
VIT_START_DOCSTRING,
)
class ViTForMaskedImageModeling(ViTPreTrainedModel):
@@ -687,7 +689,7 @@ def forward(
# Reshape to (batch_size, num_channels, height, width)
sequence_output = sequence_output[:, 1:]
batch_size, sequence_length, num_channels = sequence_output.shape
- height = width = int(sequence_length**0.5)
+ height = width = math.floor(sequence_length**0.5)
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
# Reconstruct pixel values
diff --git a/src/transformers/models/vit_mae/__init__.py b/src/transformers/models/vit_mae/__init__.py
index cc3569b8b7f69c..b785f7f6ee396b 100644
--- a/src/transformers/models/vit_mae/__init__.py
+++ b/src/transformers/models/vit_mae/__init__.py
@@ -17,14 +17,23 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
-_import_structure = {
- "configuration_vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"],
-}
+_import_structure = {"configuration_vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_vit_mae"] = [
"VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTMAEForPreTraining",
@@ -33,7 +42,12 @@
"ViTMAEPreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_vit_mae"] = [
"TFViTMAEForPreTraining",
"TFViTMAEModel",
@@ -43,7 +57,12 @@
if TYPE_CHECKING:
from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_vit_mae import (
VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTMAEForPreTraining,
@@ -52,7 +71,12 @@
ViTMAEPreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
index f464b6665afff0..803a7cccc7e97e 100644
--- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
@@ -333,7 +333,8 @@ def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
if getattr(height, "numpy", None) and getattr(width, "numpy", None):
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
- f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
)
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py
index 473ccd14feb099..f827978739af6c 100755
--- a/src/transformers/models/vit_mae/modeling_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -342,7 +342,7 @@ def __init__(self, config: ViTMAEConfig) -> None:
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
+ x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -374,7 +374,7 @@ def forward(
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
+ context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py
index 93783b668283d9..306c2197f4c3b9 100644
--- a/src/transformers/models/wav2vec2/__init__.py
+++ b/src/transformers/models/wav2vec2/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+)
_import_structure = {
@@ -28,7 +34,12 @@
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_wav2vec2"] = [
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForAudioFrameClassification",
@@ -41,7 +52,12 @@
"Wav2Vec2PreTrainedModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_wav2vec2"] = [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
@@ -49,7 +65,12 @@
"TFWav2Vec2PreTrainedModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_wav2vec2"] = [
"FlaxWav2Vec2ForCTC",
"FlaxWav2Vec2ForPreTraining",
@@ -64,7 +85,12 @@
from .processing_wav2vec2 import Wav2Vec2Processor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForAudioFrameClassification,
@@ -77,7 +103,12 @@
Wav2Vec2PreTrainedModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
@@ -85,7 +116,12 @@
TFWav2Vec2PreTrainedModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_wav2vec2 import (
FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining,
diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py
index f675f6799f66ae..6b96d9fc3f6735 100644
--- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py
@@ -78,13 +78,13 @@ class Wav2Vec2Config(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
*conv_dim*.
@@ -156,13 +156,13 @@ class Wav2Vec2Config(PretrainedConfig):
instance of [`Wav2Vec2ForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
- tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
- tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
- tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
@@ -288,10 +288,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
index db77a9ea160311..89ae3ad21c2e8c 100644
--- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
@@ -77,7 +77,8 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
if hf_shape != value.shape:
raise ValueError(
- f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
)
if weight_type == "weight":
@@ -148,14 +149,16 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
@@ -163,14 +166,16 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
- f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
index 595fb11192ad82..14b1d688c9d7a2 100644
--- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
@@ -171,8 +171,9 @@ def __call__(
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of {self.sampling_rate}. "
- f"Please make sure that the provided `raw_speech` input was sampled with {self.sampling_rate} and not {sampling_rate}."
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
index 1386ca37b075dd..7a3c6dfc5d3066 100644
--- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -137,7 +137,8 @@ def _compute_mask_indices(
if mask_length > sequence_length:
raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
+ f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
@@ -186,7 +187,7 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attentio
batch_size, sequence_length, hidden_size = features_shape
if sequence_length <= 1:
raise ValueError(
- f"`features should have `sequence_length` > 1, but are of shape "
+ "`features should have `sequence_length` > 1, but are of shape "
f"(batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
)
@@ -386,7 +387,8 @@ def setup(self):
raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
else:
raise ValueError(
- f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group',"
+ " 'layer']"
)
def __call__(self, hidden_states):
@@ -444,7 +446,8 @@ def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
)
dense = partial(
@@ -1081,7 +1084,7 @@ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
>>> input_values = processor(
... ds["speech"][0], sampling_rate=16_000, return_tensors="np"
- >>> ).input_values # Batch size 1
+ ... ).input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
```
"""
@@ -1200,7 +1203,7 @@ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
>>> input_values = processor(
... ds["speech"][0], sampling_rate=16_000, return_tensors="np"
- >>> ).input_values # Batch size 1
+ ... ).input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_ids = jnp.argmax(logits, axis=-1)
diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
index 9bbb908eb03db6..567f20040b9480 100644
--- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
@@ -16,6 +16,7 @@
import inspect
import warnings
+from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
@@ -26,7 +27,6 @@
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
from ...tf_utils import shape_list, stable_softmax
-from ...tokenization_utils_base import BatchEncoding
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -133,12 +133,14 @@ def input_values_processing(func, config, input_values, **kwargs):
output[parameter_names[i]] = input
else:
raise ValueError(
- f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
+ f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[i]}."
)
- elif isinstance(input_values, (dict, BatchEncoding)):
+ elif isinstance(input_values, Mapping):
if "inputs" in input_values:
warnings.warn(
- "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
+ "The `inputs` argument is deprecated and will be removed in a future version, use `input_values`"
+ " instead.",
FutureWarning,
)
@@ -146,7 +148,8 @@ def input_values_processing(func, config, input_values, **kwargs):
if "decoder_cached_states" in input_values:
warnings.warn(
- "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+ "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
+ " `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = input_values.pop("decoder_cached_states")
@@ -166,7 +169,8 @@ def input_values_processing(func, config, input_values, **kwargs):
output[parameter_names[0]] = input_values
else:
raise ValueError(
- f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
+ f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for"
+ f" {parameter_names[0]}."
)
for name in parameter_names:
@@ -254,7 +258,8 @@ def _compute_mask_indices(
if mask_length > sequence_length:
raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
+ f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
@@ -441,9 +446,11 @@ def _check_if_input_shape_is_none(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError(
- "Axis " + str(self.axis) + " of "
- "input tensor should have a defined dimension "
- "but the layer received an input with shape " + str(input_shape) + "."
+ "Axis "
+ + str(self.axis)
+ + " of input tensor should have a defined dimension but the layer received an input with shape "
+ + str(input_shape)
+ + "."
)
def _set_number_of_groups_for_instance_norm(self, input_shape):
@@ -457,22 +464,27 @@ def _check_size_of_dimensions(self, input_shape):
dim = input_shape[self.axis]
if dim < self.groups:
raise ValueError(
- "Number of groups (" + str(self.groups) + ") cannot be "
- "more than the number of channels (" + str(dim) + ")."
+ "Number of groups ("
+ + str(self.groups)
+ + ") cannot be more than the number of channels ("
+ + str(dim)
+ + ")."
)
if dim % self.groups != 0:
raise ValueError(
- "Number of groups (" + str(self.groups) + ") must be a "
- "multiple of the number of channels (" + str(dim) + ")."
+ "Number of groups ("
+ + str(self.groups)
+ + ") must be a multiple of the number of channels ("
+ + str(dim)
+ + ")."
)
def _check_axis(self):
if self.axis == 0:
raise ValueError(
- "You are trying to normalize your batch axis. Do you want to "
- "use tf.layer.batch_normalization instead"
+ "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead"
)
def _create_input_spec(self, input_shape):
@@ -838,7 +850,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
- message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
)
if attention_mask is not None:
@@ -848,7 +863,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
- message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
@@ -864,7 +882,10 @@ def call(
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
- message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
@@ -881,7 +902,10 @@ def call(
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
- message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
)
attn_output = tf.transpose(
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index f58ec9a3363e45..14f5c02e724ea1 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -33,6 +33,8 @@
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
@@ -88,35 +90,6 @@
]
-@dataclass
-class Wav2Vec2BaseModelOutput(ModelOutput):
- """
- Output type of [`Wav2Vec2BaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
@dataclass
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
"""
@@ -159,38 +132,6 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
diversity_loss: Optional[torch.FloatTensor] = None
-@dataclass
-class XVectorOutput(ModelOutput):
- """
- Output type of [`Wav2Vec2ForXVector`].
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Classification loss.
- logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Classification hidden states before AMSoftmax.
- embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Utterance embeddings used for vector similarity-based retrieval.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- embeddings: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
@@ -636,7 +577,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -652,7 +594,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -673,7 +616,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -801,7 +745,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- hidden_states[~attention_mask] = 0.0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -888,7 +833,8 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- hidden_states[~attention_mask] = 0
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -1022,11 +968,8 @@ def forward(self, hidden_states, mask_time_indices=None):
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
- codevectors = (
- codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
- .sum(-2)
- .view(batch_size, sequence_length, -1)
- )
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
return codevectors, perplexity
@@ -1470,13 +1413,12 @@ def forward(
```python
>>> import torch
- >>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
- >>> import soundfile as sf
- >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
- >>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
+ >>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
@@ -1910,6 +1852,7 @@ def __init__(self, config):
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
self.init_weights()
@@ -1953,6 +1896,7 @@ def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1985,12 +1929,17 @@ def forward(
logits = self.classifier(hidden_states)
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
- loss=None,
+ loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
index 53a6cfe1c07aaa..1e77959400e494 100644
--- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
@@ -61,7 +61,9 @@
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json",
},
"tokenizer_config_file": {
- "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer_config.json",
+ "facebook/wav2vec2-base-960h": (
+ "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer_config.json"
+ ),
},
}
@@ -601,7 +603,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
@@ -717,7 +719,9 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
},
"tokenizer_config_file": {
- "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
+ "facebook/wav2vec2-base-960h": (
+ "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json"
+ ),
},
}
model_input_names = ["input_values", "attention_mask"]
@@ -748,7 +752,8 @@ def __init__(
)
warnings.warn(
- "The class `Wav2Vec2Tokenizer` is deprecated and will be removed in version 5 of Transformers. Please use `Wav2Vec2Processor` or `Wav2Vec2CTCTokenizer` instead.",
+ "The class `Wav2Vec2Tokenizer` is deprecated and will be removed in version 5 of Transformers. Please use"
+ " `Wav2Vec2Processor` or `Wav2Vec2CTCTokenizer` instead.",
FutureWarning,
)
@@ -917,6 +922,6 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
diff --git a/src/transformers/models/wav2vec2_conformer/__init__.py b/src/transformers/models/wav2vec2_conformer/__init__.py
new file mode 100644
index 00000000000000..df9fe20e257158
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/__init__.py
@@ -0,0 +1,74 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_wav2vec2_conformer": [
+ "WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "Wav2Vec2ConformerConfig",
+ ],
+}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_wav2vec2_conformer"] = [
+ "WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "Wav2Vec2ConformerForAudioFrameClassification",
+ "Wav2Vec2ConformerForCTC",
+ "Wav2Vec2ConformerForPreTraining",
+ "Wav2Vec2ConformerForSequenceClassification",
+ "Wav2Vec2ConformerForXVector",
+ "Wav2Vec2ConformerModel",
+ "Wav2Vec2ConformerPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_wav2vec2_conformer import (
+ WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ Wav2Vec2ConformerConfig,
+ )
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_wav2vec2_conformer import (
+ WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForXVector,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2ConformerPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py
new file mode 100644
index 00000000000000..9c5e4d205b9af7
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py
@@ -0,0 +1,357 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Wav2Vec2Conformer model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/wav2vec2-conformer-large-rel-pos": (
+ "https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos/resolve/main/config.json"
+ ),
+}
+
+
+class Wav2Vec2ConformerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Wav2Vec2ConformerModel`]. It is used to
+ instantiate an Wav2Vec2Conformer model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Conformer
+ [facebook/wav2vec2-conformer-large-rel-pos](https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*):
+ Vocabulary size of the Wav2Vec2Conformer model. Defines the number of different tokens that can be
+ represented by the `inputs_ids` passed when calling [`Wav2Vec2ConformerModel`]. Vocabulary size of the
+ model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
+ method of [`Wav2Vec2ConformerModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ final_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for the final projection layer of [`Wav2Vec2ConformerForCTC`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+ The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+ convolutional layers.
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for output of the feature encoder.
+ feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probabilitiy for quantized feature encoder states.
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+ *conv_dim*.
+ conv_bias (`bool`, *optional*, defaults to `False`):
+ Whether the 1D convolutional layers have a bias.
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
+ embeddings layer.
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+ Number of groups of 1D convolutional positional embeddings layer.
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+ Recognition](https://arxiv.org/abs/1904.08779).
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+ procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
+ reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+ mask_time_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the time axis.
+ mask_time_min_masks (`int`, *optional*, defaults to 2),:
+ The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
+ irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
+ mask_time_min_masks''
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+ masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
+ the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
+ True`.
+ mask_feature_length (`int`, *optional*, defaults to 10):
+ Length of vector span along the feature axis.
+ mask_feature_min_masks (`int`, *optional*, defaults to 0),:
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+ step, irrespectively of `mask_feature_prob`. Only relevant if
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
+ num_codevectors_per_group (`int`, *optional*, defaults to 320):
+ Number of entries in each quantization codebook (group).
+ num_codevector_groups (`int`, *optional*, defaults to 2):
+ Number of codevector groups for product codevector quantization.
+ contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
+ The temperature *kappa* in the contrastive loss.
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probabilitiy for the output of the feature encoder that's used by the quantizer.
+ num_negatives (`int`, *optional*, defaults to 100):
+ Number of negative samples for the contrastive loss.
+ codevector_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of the quantized feature vectors.
+ proj_codevector_dim (`int`, *optional*, defaults to 256):
+ Dimensionality of the final projection of both the quantized and the transformer features.
+ diversity_loss_weight (`int`, *optional*, defaults to 0.1):
+ The weight of the codebook diversity loss component.
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+ instance of [`Wav2Vec2ConformerForCTC`].
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+ of [`Wav2Vec2ConformerForCTC`].
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+ instance of [`Wav2Vec2ConformerForSequenceClassification`].
+ classifier_proj_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the projection before token mean-pooling for classification.
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
+ xvector_output_dim (`int`, *optional*, defaults to 512):
+ Dimensionality of the *XVector* embedding vectors.
+ add_adapter (`bool`, *optional*, defaults to `False`):
+ Whether a convolutional network should be stacked on top of the Wav2Vec2Conformer Encoder. Can be very
+ useful for warm-starting Wav2Vec2Conformer for SpeechEncoderDecoder models.
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ adapter_stride (`int`, *optional*, defaults to 2):
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+ num_adapter_layers (`int`, *optional*, defaults to 3):
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+ True`.
+ output_hidden_size (`int`, *optional*):
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
+ if `add_adapter is True`.
+ position_embeddings_type (`str`, *optional*, defaults to `"relative"`):
+ Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left
+ `None` no relative position embedding is applied.
+ rotary_embedding_base (`int`, *optional*, defaults to 10000):
+ If `"rotary"` position embeddings are used, defines the size of the embedding base.
+ max_source_positions (`int`, *optional*, defaults to 5000):
+ if `"relative"` position embeddings are used, defines the maximum source input positions.
+ conv_depthwise_kernel_size (`int`, defaults to 31):
+ Kernel size of convolutional depthwise 1D layer in Conformer blocks.
+ conformer_conv_dropout (`float`, defaults to 0.1):
+ The dropout probability for all convolutional layers in Conformer blocks.
+
+ Example:
+
+ ```python
+ >>> from transformers import Wav2Vec2ConformerModel, Wav2Vec2ConformerConfig
+
+ >>> # Initializing a Wav2Vec2Conformer facebook/wav2vec2-conformer-large-rel-pos style configuration
+ >>> configuration = Wav2Vec2ConformerConfig()
+
+ >>> # Initializing a model from the facebook/wav2vec2-conformer-large-rel-pos style configuration
+ >>> model = Wav2Vec2ConformerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "wav2vec2-conformer"
+
+ def __init__(
+ self,
+ vocab_size=None,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout=0.1,
+ activation_dropout=0.1,
+ attention_dropout=0.1,
+ feat_proj_dropout=0.0,
+ feat_quantizer_dropout=0.0,
+ final_dropout=0.1,
+ layerdrop=0.1,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ feat_extract_norm="group",
+ feat_extract_activation="gelu",
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+ conv_bias=False,
+ num_conv_pos_embeddings=128,
+ num_conv_pos_embedding_groups=16,
+ apply_spec_augment=True,
+ mask_time_prob=0.05,
+ mask_time_length=10,
+ mask_time_min_masks=2,
+ mask_feature_prob=0.0,
+ mask_feature_length=10,
+ mask_feature_min_masks=0,
+ num_codevectors_per_group=320,
+ num_codevector_groups=2,
+ contrastive_logits_temperature=0.1,
+ num_negatives=100,
+ codevector_dim=256,
+ proj_codevector_dim=256,
+ diversity_loss_weight=0.1,
+ ctc_loss_reduction="sum",
+ ctc_zero_infinity=False,
+ use_weighted_layer_sum=False,
+ classifier_proj_size=256,
+ tdnn_dim=(512, 512, 512, 512, 1500),
+ tdnn_kernel=(5, 3, 3, 1, 1),
+ tdnn_dilation=(1, 2, 3, 1, 1),
+ xvector_output_dim=512,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ add_adapter=False,
+ adapter_kernel_size=3,
+ adapter_stride=2,
+ num_adapter_layers=3,
+ output_hidden_size=None,
+ position_embeddings_type="relative",
+ rotary_embedding_base=10000,
+ max_source_positions=5000,
+ conv_depthwise_kernel_size=31,
+ conformer_conv_dropout=0.1,
+ **kwargs
+ ):
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+ self.hidden_size = hidden_size
+ self.feat_extract_norm = feat_extract_norm
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = list(conv_dim)
+ self.conv_stride = list(conv_stride)
+ self.conv_kernel = list(conv_kernel)
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.num_feat_extract_layers = len(self.conv_dim)
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.feat_proj_dropout = feat_proj_dropout
+ self.final_dropout = final_dropout
+ self.layerdrop = layerdrop
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.use_weighted_layer_sum = use_weighted_layer_sum
+ self.max_source_positions = max_source_positions
+ self.position_embeddings_type = position_embeddings_type
+ self.rotary_embedding_base = rotary_embedding_base
+
+ if (
+ (len(self.conv_stride) != self.num_feat_extract_layers)
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
+ ):
+ raise ValueError(
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ )
+
+ # Conformer-block related
+ self.conv_depthwise_kernel_size = conv_depthwise_kernel_size
+ self.conformer_conv_dropout = conformer_conv_dropout
+
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
+ self.apply_spec_augment = apply_spec_augment
+ self.mask_time_prob = mask_time_prob
+ self.mask_time_length = mask_time_length
+ self.mask_time_min_masks = mask_time_min_masks
+ self.mask_feature_prob = mask_feature_prob
+ self.mask_feature_length = mask_feature_length
+ self.mask_feature_min_masks = mask_feature_min_masks
+
+ # parameters for pretraining with codevector quantized representations
+ self.num_codevectors_per_group = num_codevectors_per_group
+ self.num_codevector_groups = num_codevector_groups
+ self.contrastive_logits_temperature = contrastive_logits_temperature
+ self.feat_quantizer_dropout = feat_quantizer_dropout
+ self.num_negatives = num_negatives
+ self.codevector_dim = codevector_dim
+ self.proj_codevector_dim = proj_codevector_dim
+ self.diversity_loss_weight = diversity_loss_weight
+
+ # ctc loss
+ self.ctc_loss_reduction = ctc_loss_reduction
+ self.ctc_zero_infinity = ctc_zero_infinity
+
+ # adapter
+ self.add_adapter = add_adapter
+ self.adapter_kernel_size = adapter_kernel_size
+ self.adapter_stride = adapter_stride
+ self.num_adapter_layers = num_adapter_layers
+ self.output_hidden_size = output_hidden_size or hidden_size
+
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
+ self.classifier_proj_size = classifier_proj_size
+
+ # XVector-specific parameters. Feel free to ignore for other classes.
+ self.tdnn_dim = list(tdnn_dim)
+ self.tdnn_kernel = list(tdnn_kernel)
+ self.tdnn_dilation = list(tdnn_dilation)
+ self.xvector_output_dim = xvector_output_dim
+
+ @property
+ def inputs_to_logits_ratio(self):
+ return functools.reduce(operator.mul, self.conv_stride, 1)
diff --git a/src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 00000000000000..26ccf9239b6115
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,307 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Wav2Vec2Conformer checkpoint."""
+
+
+import argparse
+import json
+import os
+
+import fairseq
+import torch
+from fairseq.data import Dictionary
+
+from transformers import (
+ Wav2Vec2ConformerConfig,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2CTCTokenizer,
+ Wav2Vec2FeatureExtractor,
+ Wav2Vec2Processor,
+ logging,
+)
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+MAPPING = {
+ "post_extract_proj": "feature_projection.projection",
+ "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
+ "self_attn.linear_k": "encoder.layers.*.self_attn.linear_k",
+ "self_attn.linear_v": "encoder.layers.*.self_attn.linear_v",
+ "self_attn.linear_q": "encoder.layers.*.self_attn.linear_q",
+ "self_attn.pos_bias_u": "encoder.layers.*.self_attn.pos_bias_u",
+ "self_attn.pos_bias_v": "encoder.layers.*.self_attn.pos_bias_v",
+ "self_attn.linear_out": "encoder.layers.*.self_attn.linear_out",
+ "self_attn.linear_pos": "encoder.layers.*.self_attn.linear_pos",
+ "self_attn.rotary_emb": "encoder.embed_positions",
+ "self_attn_layer_norm": "encoder.layers.*.self_attn_layer_norm",
+ "conv_module.pointwise_conv1": "encoder.layers.*.conv_module.pointwise_conv1",
+ "conv_module.pointwise_conv2": "encoder.layers.*.conv_module.pointwise_conv2",
+ "conv_module.depthwise_conv": "encoder.layers.*.conv_module.depthwise_conv",
+ "conv_module.batch_norm": "encoder.layers.*.conv_module.batch_norm",
+ "conv_module.layer_norm": "encoder.layers.*.conv_module.layer_norm",
+ "ffn1.w_1": "encoder.layers.*.ffn1.intermediate_dense",
+ "ffn1.w_2": "encoder.layers.*.ffn1.output_dense",
+ "ffn1.layer_norm": "encoder.layers.*.ffn1_layer_norm",
+ "ffn2.w_1": "encoder.layers.*.ffn2.intermediate_dense",
+ "ffn2.w_2": "encoder.layers.*.ffn2.output_dense",
+ "ffn2.layer_norm": "encoder.layers.*.ffn2_layer_norm",
+ "final_layer_norm": "encoder.layers.*.final_layer_norm",
+ "encoder.layer_norm": "encoder.layer_norm",
+ "w2v_model.layer_norm": "feature_projection.layer_norm",
+ "quantizer.weight_proj": "quantizer.weight_proj",
+ "quantizer.vars": "quantizer.codevectors",
+ "project_q": "project_q",
+ "final_proj": "project_hid",
+ "w2v_encoder.proj": "lm_head",
+ "mask_emb": "masked_spec_embed",
+}
+TOP_LEVEL_KEYS = [
+ "lm_head",
+ "quantizer.weight_proj",
+ "quantizer.codevectors",
+ "project_q",
+ "project_hid",
+]
+
+
+def set_recursively(hf_pointer, key, value, full_name, weight_type):
+ for attribute in key.split("."):
+ hf_pointer = getattr(hf_pointer, attribute)
+
+ if weight_type is not None:
+ hf_shape = getattr(hf_pointer, weight_type).shape
+ else:
+ hf_shape = hf_pointer.shape
+
+ if hf_shape != value.shape:
+ raise ValueError(
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
+
+ if weight_type == "weight":
+ hf_pointer.weight.data = value
+ elif weight_type == "weight_g":
+ hf_pointer.weight_g.data = value
+ elif weight_type == "weight_v":
+ hf_pointer.weight_v.data = value
+ elif weight_type == "bias":
+ hf_pointer.bias.data = value
+ elif weight_type == "running_mean":
+ hf_pointer.running_mean.data = value
+ elif weight_type == "running_var":
+ hf_pointer.running_var.data = value
+ elif weight_type == "num_batches_tracked":
+ hf_pointer.num_batches_tracked.data = value
+ elif weight_type == "inv_freq":
+ hf_pointer.inv_freq.data = value
+ else:
+ hf_pointer.data = value
+
+ logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
+
+
+def recursively_load_weights(fairseq_model, hf_model, is_headless):
+ unused_weights = []
+ fairseq_dict = fairseq_model.state_dict()
+
+ feature_extractor = hf_model.wav2vec2_conformer.feature_extractor
+
+ for name, value in fairseq_dict.items():
+ is_used = False
+ if "conv_layers" in name:
+ load_conv_layer(
+ name,
+ value,
+ feature_extractor,
+ unused_weights,
+ hf_model.config.feat_extract_norm == "group",
+ )
+ is_used = True
+ else:
+ for key, mapped_key in MAPPING.items():
+ mapped_key = "wav2vec2_conformer." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
+ if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
+ is_used = True
+ if "*" in mapped_key:
+ layer_index = name.split(key)[0].split(".")[-2]
+ mapped_key = mapped_key.replace("*", layer_index)
+ if "pos_bias_u" in name:
+ weight_type = None
+ elif "pos_bias_v" in name:
+ weight_type = None
+ elif "weight_g" in name:
+ weight_type = "weight_g"
+ elif "weight_v" in name:
+ weight_type = "weight_v"
+ elif "bias" in name:
+ weight_type = "bias"
+ elif "weight" in name:
+ # TODO: don't match quantizer.weight_proj
+ weight_type = "weight"
+ elif "running_mean" in name:
+ weight_type = "running_mean"
+ elif "inv_freq" in name:
+ weight_type = "inv_freq"
+ elif "running_var" in name:
+ weight_type = "running_var"
+ elif "num_batches_tracked" in name:
+ weight_type = "num_batches_tracked"
+ else:
+ weight_type = None
+ set_recursively(hf_model, mapped_key, value, name, weight_type)
+ continue
+ if not is_used:
+ unused_weights.append(name)
+
+ logger.warning(f"Unused weights: {unused_weights}")
+
+
+# Copied from transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.load_conv_layer
+def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
+ name = full_name.split("conv_layers.")[-1]
+ items = name.split(".")
+ layer_id = int(items[0])
+ type_id = int(items[1])
+
+ if type_id == 0:
+ if "bias" in name:
+ if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
+ raise ValueError(
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
+ feature_extractor.conv_layers[layer_id].conv.bias.data = value
+ logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+ elif "weight" in name:
+ if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
+ raise ValueError(
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
+ feature_extractor.conv_layers[layer_id].conv.weight.data = value
+ logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+ elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
+ if "bias" in name:
+ if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
+ raise ValueError(
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
+ )
+ feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
+ logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+ elif "weight" in name:
+ if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
+ raise ValueError(
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
+ )
+ feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
+ logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+ else:
+ unused_weights.append(full_name)
+
+
+@torch.no_grad()
+def convert_wav2vec2_conformer_checkpoint(
+ checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
+):
+ """
+ Copy/paste/tweak model's weights to transformers design.
+ """
+ if config_path is not None:
+ config = Wav2Vec2ConformerConfig.from_pretrained(config_path, hidden_act="swish")
+ else:
+ config = Wav2Vec2ConformerConfig()
+
+ if "rope" in checkpoint_path:
+ config.position_embeddings_type = "rotary"
+
+ if is_finetuned:
+ if dict_path:
+ target_dict = Dictionary.load(dict_path)
+
+ # important change bos & pad token id since CTC symbol is and
+ # not as in fairseq
+ config.bos_token_id = target_dict.pad_index
+ config.pad_token_id = target_dict.bos_index
+ config.eos_token_id = target_dict.eos_index
+ config.vocab_size = len(target_dict.symbols)
+ vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
+ if not os.path.isdir(pytorch_dump_folder_path):
+ logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
+ return
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+ vocab_dict = target_dict.indices
+
+ # fairseq has the and switched
+ vocab_dict[""] = 0
+ vocab_dict[""] = 1
+ with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
+ json.dump(vocab_dict, vocab_handle)
+ tokenizer = Wav2Vec2CTCTokenizer(
+ vocab_path,
+ unk_token=target_dict.unk_word,
+ pad_token=target_dict.pad_word,
+ bos_token=target_dict.bos_word,
+ eos_token=target_dict.eos_word,
+ word_delimiter_token="|",
+ do_lower_case=False,
+ )
+ return_attention_mask = True if config.feat_extract_norm == "layer" else False
+ feature_extractor = Wav2Vec2FeatureExtractor(
+ feature_size=1,
+ sampling_rate=16000,
+ padding_value=0,
+ do_normalize=True,
+ return_attention_mask=return_attention_mask,
+ )
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ hf_wav2vec = Wav2Vec2ConformerForCTC(config)
+ else:
+ hf_wav2vec = Wav2Vec2ConformerForPreTraining(config)
+
+ if is_finetuned:
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+ [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
+ )
+ else:
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
+
+ model = model[0].eval()
+
+ recursively_load_weights(model, hf_wav2vec, not is_finetuned)
+
+ hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
+ parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+ parser.add_argument(
+ "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
+ )
+ args = parser.parse_args()
+ convert_wav2vec2_conformer_checkpoint(
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
+ )
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
new file mode 100644
index 00000000000000..d4972f65b43f1b
--- /dev/null
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -0,0 +1,2126 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Wav2Vec2-Conformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 64.21
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-seq-class"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.68
+
+# Frame class docstring
+_FRAME_CLASS_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-frame-class"
+_FRAME_EXPECTED_OUTPUT = [1, 0]
+
+# Speaker Verification docstring
+_XVECTOR_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-xvector"
+_XVECTOR_EXPECTED_OUTPUT = 1.0
+
+
+WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/wav2vec2-conformer-large-rel-pos",
+ # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
+]
+
+
+@dataclass
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
+
+ Args:
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
+ paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+ projected quantized states.
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+ target vectors for contrastive loss.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ projected_states: torch.FloatTensor = None
+ projected_quantized_states: torch.FloatTensor = None
+ codevector_perplexity: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ contrastive_loss: Optional[torch.FloatTensor] = None
+ diversity_loss: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.sum(-1).detach().tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller then
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
+def _sample_negative_indices(
+ features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+):
+ """
+ Sample `num_negatives` vectors from feature vectors.
+ """
+ batch_size, sequence_length = features_shape
+
+ # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+ sequence_length_range = np.arange(sequence_length)
+
+ # get `num_negatives` random vector indices from the same utterance
+ sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+ mask_time_indices = (
+ mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)
+ )
+
+ for batch_idx in range(batch_size):
+ high = mask_time_indices[batch_idx].sum() - 1
+ mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+ feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+ sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+ # avoid sampling the same positive vector, but keep the distribution uniform
+ sampled_indices[sampled_indices >= feature_indices] += 1
+
+ # remap to actual indices
+ sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+ # correct for batch size
+ sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+ return sampled_negative_indices
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
+ else:
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+ self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
+ """Rotary positional embedding
+ Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ dim = config.hidden_size // config.num_attention_heads
+ base = config.rotary_embedding_base
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.cached_sequence_length = None
+ self.cached_rotary_positional_embedding = None
+
+ def forward(self, hidden_states):
+ sequence_length = hidden_states.shape[1]
+
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+ return self.cached_rotary_positional_embedding
+
+ self.cached_sequence_length = sequence_length
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+ embeddings = torch.cat((freqs, freqs), dim=-1)
+
+ cos_embeddings = embeddings.cos()[:, None, None, :]
+ sin_embeddings = embeddings.sin()[:, None, None, :]
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
+ return self.cached_rotary_positional_embedding
+
+
+class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
+ """Relative positional encoding module."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.max_len = config.max_source_positions
+ self.d_model = config.hidden_size
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+ def extend_pe(self, x):
+ # Reset the positional encodings
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` is the position of query vector and `j` is the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (iWav2Vec2Conformer
+class Wav2Vec2ConformerSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
+ super().__init__()
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+ def forward(self, hidden_states):
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureEncoder(nn.Module):
+ """Construct the features from raw audio waveform"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ if config.feat_extract_norm == "group":
+ conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
+ Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
+ for i in range(config.num_feat_extract_layers - 1)
+ ]
+ elif config.feat_extract_norm == "layer":
+ conv_layers = [
+ Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
+ ]
+ else:
+ raise ValueError(
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ )
+ self.conv_layers = nn.ModuleList(conv_layers)
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def forward(self, input_values):
+ hidden_states = input_values[:, None]
+
+ # make sure hidden_states require grad for gradient_checkpointing
+ if self._requires_grad and self.training:
+ hidden_states.requires_grad = True
+
+ for conv_layer in self.conv_layers:
+ if self._requires_grad and self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(conv_layer),
+ hidden_states,
+ )
+ else:
+ hidden_states = conv_layer(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ norm_hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(norm_hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states, norm_hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeedForward(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.intermediate_dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.intermediate_dropout(hidden_states)
+
+ hidden_states = self.output_dense(hidden_states)
+ hidden_states = self.output_dropout(hidden_states)
+ return hidden_states
+
+
+class Wav2Vec2ConformerConvolutionModule(nn.Module):
+ """Convolution block used in the conformer block"""
+
+ def __init__(self, config):
+ super().__init__()
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+ raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.pointwise_conv1 = torch.nn.Conv1d(
+ config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.glu = torch.nn.GLU(dim=1)
+ self.depthwise_conv = torch.nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ config.conv_depthwise_kernel_size,
+ stride=1,
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
+ groups=config.hidden_size,
+ bias=False,
+ )
+ self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.pointwise_conv2 = torch.nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.layer_norm(hidden_states)
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism
+ # => (batch, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # => (batch, channel, dim)
+ hidden_states = self.glu(hidden_states)
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.batch_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2ConformerSelfAttention(nn.Module):
+ """Construct an Wav2Vec2ConformerSelfAttention object.
+ Can be enhanced with rotary or relative position embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.head_size = config.hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.position_embeddings_type = config.position_embeddings_type
+
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+
+ if self.position_embeddings_type == "relative":
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # self-attention mechanism
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ # make sure query/key states can be != value states
+ query_key_states = hidden_states
+ value_states = hidden_states
+
+ if self.position_embeddings_type == "rotary":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
+ )
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+ # project query_key_states and value_states
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ if self.position_embeddings_type == "relative":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
+ " 'relative'"
+ )
+ # apply relative_position_embeddings to qk scores
+ # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
+ scores = self._apply_relative_embeddings(
+ query=query, key=key, relative_position_embeddings=relative_position_embeddings
+ )
+ else:
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
+
+ # apply attention_mask if necessary
+ if attention_mask is not None:
+ scores = scores + attention_mask
+
+ # => (batch, head, time1, time2)
+ probs = torch.softmax(scores, dim=-1)
+ probs = self.dropout(probs)
+
+ # => (batch, head, time1, d_k)
+ hidden_states = torch.matmul(probs, value)
+
+ # => (batch, time1, hidden_size)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, probs
+
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+ cos = relative_position_embeddings[0, :sequence_length, ...]
+ sin = relative_position_embeddings[1, :sequence_length, ...]
+
+ # rotate hidden_states with rotary embeddings
+ hidden_states = hidden_states.transpose(0, 1)
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
+ hidden_states = hidden_states.transpose(0, 1)
+
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+ return hidden_states
+
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+ # 1. project positional embeddings
+ # => (batch, head, 2*time1-1, d_k)
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
+ )
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
+
+ # 2. Add bias to query
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ # 3. attention score: first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # => (batch, head, time1, time2)
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ # 4. then compute matrix b and matrix d
+ # => (batch, head, time1, 2*time1-1)
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+ # 5. shift matrix b and matrix d
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
+
+ # 6. sum matrices
+ # => (batch, head, time1, time2)
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+ return scores
+
+
+class Wav2Vec2ConformerEncoderLayer(nn.Module):
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ dropout = config.attention_dropout
+
+ # Feed-forward 1
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn1 = Wav2Vec2ConformerFeedForward(config)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+ self.self_attn_dropout = torch.nn.Dropout(dropout)
+ self.self_attn = Wav2Vec2ConformerSelfAttention(config)
+
+ # Conformer Convolution
+ self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
+
+ # Feed-forward 2
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn2 = Wav2Vec2ConformerFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+ hidden_states = hidden_states
+
+ # 1. Feed-Forward 1 layer
+ residual = hidden_states
+ hidden_states = self.ffn1_layer_norm(hidden_states)
+ hidden_states = self.ffn1(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ residual = hidden_states
+
+ # 2. Self-Attention layer
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weigts = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ # 3. Convolutional Layer
+ residual = hidden_states
+ hidden_states = self.conv_module(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # 4. Feed-Forward 2 Layer
+ residual = hidden_states
+ hidden_states = self.ffn2_layer_norm(hidden_states)
+ hidden_states = self.ffn2(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return hidden_states, attn_weigts
+
+
+class Wav2Vec2ConformerEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ if config.position_embeddings_type == "relative":
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
+ elif config.position_embeddings_type == "rotary":
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
+ else:
+ self.embed_positions = None
+
+ self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ hidden_states[~attention_mask] = 0.0
+
+ # extend attention_mask
+ attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ hidden_states = self.dropout(hidden_states)
+
+ if self.embed_positions is not None:
+ relative_position_embeddings = self.embed_positions(hidden_states)
+ else:
+ relative_position_embeddings = None
+
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = np.random.uniform(0, 1)
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
+ # under deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ # create gradient checkpointing function
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer),
+ hidden_states,
+ attention_mask,
+ relative_position_embeddings,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
+ """
+ Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
+ GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_groups = config.num_codevector_groups
+ self.num_vars = config.num_codevectors_per_group
+
+ if config.codevector_dim % self.num_groups != 0:
+ raise ValueError(
+ f"`config.codevector_dim {config.codevector_dim} must be divisible "
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ )
+
+ # storage for codebook variables (codewords)
+ self.codevectors = nn.Parameter(
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
+ )
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
+
+ # can be decayed for training
+ self.temperature = 2
+
+ @staticmethod
+ def _compute_perplexity(probs, mask=None):
+ if mask is not None:
+ mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
+ probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
+ marginal_probs = probs.sum(dim=0) / mask.sum()
+ else:
+ marginal_probs = probs.mean(dim=0)
+
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+ return perplexity
+
+ def forward(self, hidden_states, mask_time_indices=None):
+ batch_size, sequence_length, hidden_size = hidden_states.shape
+
+ # project to codevector dim
+ hidden_states = self.weight_proj(hidden_states)
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+ if self.training:
+ # sample code vector probs via gumbel in differentiateable way
+ codevector_probs = nn.functional.gumbel_softmax(
+ hidden_states.float(), tau=self.temperature, hard=True
+ ).type_as(hidden_states)
+
+ # compute perplexity
+ codevector_soft_dist = torch.softmax(
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+ )
+ perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+ else:
+ # take argmax in non-differentiable way
+ # comptute hard codevector distribution (one hot)
+ codevector_idx = hidden_states.argmax(dim=-1)
+ codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+ -1, codevector_idx.view(-1, 1), 1.0
+ )
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+ perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+ # use probs to retrieve codevectors
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+ return codevectors, perplexity
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+ else:
+ self.proj = self.proj_layer_norm = None
+
+ self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
+ self.layerdrop = config.layerdrop
+
+ def forward(self, hidden_states):
+ # down project hidden_states if necessary
+ if self.proj is not None and self.proj_layer_norm is not None:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.proj_layer_norm(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+
+ for layer in self.layers:
+ layerdrop_prob = np.random.random()
+ if not self.training or (layerdrop_prob > self.layerdrop):
+ hidden_states = layer(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.output_hidden_size,
+ 2 * config.output_hidden_size,
+ config.adapter_kernel_size,
+ stride=config.adapter_stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ return hidden_states
+
+
+class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Wav2Vec2ConformerConfig
+ base_model_prefix = "wav2vec2_conformer"
+ main_input_name = "input_values"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # gumbel softmax requires special init
+ if isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, Wav2Vec2ConformerSelfAttention):
+ if hasattr(module, "pos_bias_u"):
+ nn.init.xavier_uniform_(module.pos_bias_u)
+ if hasattr(module, "pos_bias_v"):
+ nn.init.xavier_uniform_(module.pos_bias_v)
+ elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch_int_div(input_length - kernel_size, stride) + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ if add_adapter:
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # on inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations makes sure that all values before the output lengths idxs are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
+ module.gradient_checkpointing = value
+
+
+WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
+ Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+ Auli.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving etc.).
+
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
+ regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
+
+ Parameters:
+ config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
+ into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
+ soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
+ and conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
+ [wav2vec2_conformer-base](https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos),
+ `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
+ such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
+ that these models also yield slightly different results depending on whether `input_values` is padded or
+ not.
+
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
+ self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+ self.encoder = Wav2Vec2ConformerEncoder(config)
+
+ self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+ elif self.config.mask_time_prob > 0 and self.training:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+ hidden_states[mask_feature_indices] = 0
+
+ return hidden_states
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_PROCESSOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Wav2Vec2BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return Wav2Vec2BaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
+)
+class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+ self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # make sure that project_hid & project_q are initialized like normal linear layers
+ self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
+ def set_gumbel_temperature(self, temperature: int):
+ """
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
+ """
+ self.quantizer.temperature = temperature
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ @staticmethod
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
+ def compute_contrastive_logits(
+ target_features: torch.FloatTensor,
+ negative_features: torch.FloatTensor,
+ predicted_features: torch.FloatTensor,
+ temperature: int = 0.1,
+ ):
+ """
+ Compute logits for contrastive loss based using cosine similarity as the distance measure between
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+ """
+ target_features = torch.cat([target_features, negative_features], dim=0)
+
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
+ target_features
+ )
+
+ # apply temperature
+ logits = logits / temperature
+ return logits
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.BoolTensor] = None,
+ sampled_negative_indices: Optional[torch.BoolTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+ masked extracted features in *config.proj_codevector_dim* space.
+ sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+ Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+ Required input for pre-training.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
+ >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
+ >>> from datasets import load_dataset
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
+
+ >>> # compute masked indices
+ >>> batch_size, raw_sequence_length = input_values.shape
+ >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
+ >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
+ >>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
+
+ >>> with torch.no_grad():
+ ... outputs = model(input_values, mask_time_indices=mask_time_indices)
+
+ >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
+ >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+ >>> # show that cosine similarity is much higher than random
+ >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
+ tensor(True)
+
+ >>> # for contrastive loss training model should be put into train mode
+ >>> model = model.train()
+ >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if mask_time_indices is not None:
+ mask_time_indices = mask_time_indices.to(torch.bool)
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ mask_time_indices=mask_time_indices,
+ return_dict=return_dict,
+ )
+
+ # 1. project all transformed features (including masked) to final vq dim
+ transformer_features = self.project_hid(outputs[0])
+
+ # 2. quantize all (unmasked) extracted features and project to final vq dim
+ extract_features = self.dropout_features(outputs[1])
+
+ if attention_mask is not None:
+ # compute reduced attention_mask correponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ quantized_features, codevector_perplexity = self.quantizer(
+ extract_features, mask_time_indices=mask_time_indices
+ )
+ quantized_features = self.project_q(quantized_features)
+
+ loss = contrastive_loss = diversity_loss = None
+ if sampled_negative_indices is not None:
+ batch_size, sequence_length, hidden_size = quantized_features.shape
+
+ # for training, we sample negatives
+ # 3. sample K negatives (distractors) quantized states for contrastive loss
+ # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
+ # sample negative quantized vectors BTC => (BxT)C
+ negative_quantized_features = quantized_features.view(-1, hidden_size)[
+ sampled_negative_indices.long().view(-1)
+ ]
+ negative_quantized_features = negative_quantized_features.view(
+ batch_size, sequence_length, -1, hidden_size
+ ).permute(2, 0, 1, 3)
+
+ # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
+ # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
+ logits = self.compute_contrastive_logits(
+ quantized_features[None, :],
+ negative_quantized_features,
+ transformer_features,
+ self.config.contrastive_logits_temperature,
+ )
+
+ # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
+ # its cosine similarity will be masked
+ neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
+
+ if neg_is_pos.any():
+ logits[1:][neg_is_pos] = float("-inf")
+
+ # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
+ # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
+ logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
+ target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
+
+ contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+ # 7. compute diversity loss: \mathbf{L}_d
+ num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
+ diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+
+ # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
+ loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
+
+ if not return_dict:
+ if loss is not None:
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+ return Wav2Vec2ConformerForPreTrainingOutput(
+ loss=loss,
+ projected_states=transformer_features,
+ projected_quantized_states=quantized_features,
+ codevector_perplexity=codevector_perplexity,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ contrastive_loss=contrastive_loss,
+ diversity_loss=diversity_loss,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
+ " vocabulary size of the language model head. Please instantiate the model as follows:"
+ " `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of"
+ " your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_PROCESSOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+
+ if labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
+ tasks like SUPERB Keyword Spotting.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Sequence classification does not support the use of Wav2Vec2Conformer adapters"
+ " (config.add_adapter=True)"
+ )
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_SEQ_CLASS_CHECKPOINT,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ if attention_mask is None:
+ pooled_output = hidden_states.mean(dim=1)
+ else:
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+ hidden_states[~padding_mask] = 0.0
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Audio frame classification does not support the use of Wav2Vec2Conformer adapters"
+ " (config.add_adapter=True)"
+ )
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_FRAME_CLASS_CHECKPOINT,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_FRAME_EXPECTED_OUTPUT,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
+class AMSoftmaxLoss(nn.Module):
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+ super(AMSoftmaxLoss, self).__init__()
+ self.scale = scale
+ self.margin = margin
+ self.num_labels = num_labels
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+ self.loss = nn.CrossEntropyLoss()
+
+ def forward(self, hidden_states, labels):
+ labels = labels.flatten()
+ weight = nn.functional.normalize(self.weight, dim=0)
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
+ cos_theta = torch.mm(hidden_states, weight)
+ psi = cos_theta - self.margin
+
+ onehot = nn.functional.one_hot(labels, self.num_labels)
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+ loss = self.loss(logits, labels)
+
+ return loss
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
+class TDNNLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+ self.out_conv_dim = config.tdnn_dim[layer_id]
+ self.kernel_size = config.tdnn_kernel[layer_id]
+ self.dilation = config.tdnn_dilation[layer_id]
+
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+ self.activation = nn.ReLU()
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.unsqueeze(1)
+ hidden_states = nn.functional.unfold(
+ hidden_states,
+ (self.kernel_size, self.in_conv_dim),
+ stride=(1, self.in_conv_dim),
+ dilation=(self.dilation, 1),
+ )
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.kernel(hidden_states)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+ self.tdnn = nn.ModuleList(tdnn_layers)
+
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the TDNN layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return (input_length - kernel_size) // stride + 1
+
+ for kernel_size in self.config.tdnn_kernel:
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+ return input_lengths
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_XVECTOR_CHECKPOINT,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_XVECTOR_EXPECTED_OUTPUT,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, XVectorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+
+ for tdnn_layer in self.tdnn:
+ hidden_states = tdnn_layer(hidden_states)
+
+ # Statistic Pooling
+ if attention_mask is None:
+ mean_features = hidden_states.mean(dim=1)
+ std_features = hidden_states.std(dim=1)
+ else:
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+ mean_features = []
+ std_features = []
+ for i, length in enumerate(tdnn_output_lengths):
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
+ std_features.append(hidden_states[i, :length].std(dim=0))
+ mean_features = torch.stack(mean_features)
+ std_features = torch.stack(std_features)
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
+
+ output_embeddings = self.feature_extractor(statistic_pooling)
+ logits = self.classifier(output_embeddings)
+
+ loss = None
+ if labels is not None:
+ loss = self.objective(logits, labels)
+
+ if not return_dict:
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return XVectorOutput(
+ loss=loss,
+ logits=logits,
+ embeddings=output_embeddings,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/transformers/models/wav2vec2_phoneme/__init__.py b/src/transformers/models/wav2vec2_phoneme/__init__.py
index 4d6ea18a330eb1..84dc9942d515b5 100644
--- a/src/transformers/models/wav2vec2_phoneme/__init__.py
+++ b/src/transformers/models/wav2vec2_phoneme/__init__.py
@@ -20,11 +20,7 @@
from ...utils import _LazyModule
-# fmt: off
-_import_structure = {
- "tokenization_wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"]
-}
-# fmt: on
+_import_structure = {"tokenization_wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py
index 6bd355645e5a3f..c983c4be826430 100644
--- a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py
+++ b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py
@@ -55,10 +55,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "facebook/wav2vec2-lv-60-espeak-cv-ft": "https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft/resolve/main/vocab.json",
+ "facebook/wav2vec2-lv-60-espeak-cv-ft": (
+ "https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft/resolve/main/vocab.json"
+ ),
},
"tokenizer_config_file": {
- "facebook/wav2vec2-lv-60-espeak-cv-ft": "https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft/resolve/main/tokenizer_config.json",
+ "facebook/wav2vec2-lv-60-espeak-cv-ft": (
+ "https://huggingface.co/facebook/wav2vec2-lv-60-espeak-cv-ft/resolve/main/tokenizer_config.json"
+ ),
},
}
@@ -369,7 +373,7 @@ def convert_tokens_to_string(
if len(char_offsets) != len(processed_chars):
raise ValueError(
f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
- f" have to be of the same length, but are: `len(offsets)`: "
+ " have to be of the same length, but are: `len(offsets)`: "
f"{len(char_offsets)} and `len(processed_tokens)`: {len(processed_chars)}"
)
@@ -564,7 +568,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
@@ -600,7 +604,7 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to
tokens_to_add = []
for token in new_tokens:
if not isinstance(token, str):
- raise ValueError(f"Token {token} has to be of type string, but is " f"of type {type(token)}.")
+ raise ValueError(f"Token {token} has to be of type string, but is of type {type(token)}.")
assert isinstance(token, str)
if (
token != self.unk_token
diff --git a/src/transformers/models/wav2vec2_with_lm/__init__.py b/src/transformers/models/wav2vec2_with_lm/__init__.py
index 8730f3508e3087..174946ae10181a 100644
--- a/src/transformers/models/wav2vec2_with_lm/__init__.py
+++ b/src/transformers/models/wav2vec2_with_lm/__init__.py
@@ -20,11 +20,7 @@
from ...utils import _LazyModule
-# fmt: off
-_import_structure = {
- "processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"]
-}
-# fmt: on
+_import_structure = {"processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"]}
if TYPE_CHECKING:
diff --git a/src/transformers/models/wavlm/__init__.py b/src/transformers/models/wavlm/__init__.py
index 576bbaf83cdfc6..9cd64b25dafaf0 100644
--- a/src/transformers/models/wavlm/__init__.py
+++ b/src/transformers/models/wavlm/__init__.py
@@ -17,14 +17,17 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
-_import_structure = {
- "configuration_wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig"],
-}
+_import_structure = {"configuration_wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_wavlm"] = [
"WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"WavLMForAudioFrameClassification",
@@ -38,7 +41,12 @@
if TYPE_CHECKING:
from .configuration_wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_wavlm import (
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
WavLMForAudioFrameClassification,
diff --git a/src/transformers/models/wavlm/configuration_wavlm.py b/src/transformers/models/wavlm/configuration_wavlm.py
index d7f0b7047030cf..3257d1e986cdcd 100644
--- a/src/transformers/models/wavlm/configuration_wavlm.py
+++ b/src/transformers/models/wavlm/configuration_wavlm.py
@@ -77,13 +77,13 @@ class WavLMConfig(PretrainedConfig):
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
The dropout probabilitiy for quantized feature encoder states.
- conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
- conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
of *conv_stride* defines the number of convolutional layers and has to match the the length of *conv_dim*.
- conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
length of *conv_kernel* defines the number of convolutional layers and has to match the the length of
*conv_dim*.
@@ -146,13 +146,13 @@ class WavLMConfig(PretrainedConfig):
instance of [`WavLMForSequenceClassification`].
classifier_proj_size (`int`, *optional*, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
- tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
- tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
- tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
xvector_output_dim (`int`, *optional*, defaults to 512):
@@ -290,10 +290,10 @@ def __init__(
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
- "Configuration for convolutional layers is incorrect. "
- "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
- f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
- f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
diff --git a/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
index 8523fa87eba820..91758cc9595290 100644
--- a/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
@@ -74,9 +74,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
- assert (
- hf_shape == value.shape
- ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
+ assert hf_shape == value.shape, (
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
+ f" {value.shape} for {full_name}"
+ )
if weight_type == "weight":
hf_pointer.weight.data = value
@@ -144,28 +145,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape, (
+ f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was"
+ " found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
- assert (
- value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
- ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ assert value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape, (
+ f"{full_name} has size {value.shape}, but"
+ f" {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
+ )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py
index c2eb193160ad37..d945545af4818d 100755
--- a/src/transformers/models/wavlm/modeling_wavlm.py
+++ b/src/transformers/models/wavlm/modeling_wavlm.py
@@ -16,7 +16,6 @@
import math
import warnings
-from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
@@ -28,16 +27,17 @@
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
-from ...utils import (
- ModelOutput,
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
-)
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_wavlm import WavLMConfig
@@ -80,67 +80,6 @@
]
-@dataclass
-class WavLMBaseModelOutput(ModelOutput):
- """
- Output type of [`WavLMBaseModelOutput`], with potential hidden states and attentions.
-
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
- Sequence of extracted feature vectors of the last convolutional layer of the model.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- last_hidden_state: torch.FloatTensor = None
- extract_features: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
-@dataclass
-class XVectorOutput(ModelOutput):
- """
- Output type of [`Wav2Vec2ForXVector`].
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Classification loss.
- logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Classification hidden states before AMSoftmax.
- embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
- Utterance embeddings used for vector similarity-based retrieval.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- embeddings: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
@@ -620,12 +559,12 @@ def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> t
relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
- relative_postion_if_large = (max_exact + relative_positions_if_large).to(torch.long)
- relative_postion_if_large = torch.min(
- relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+ relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
- relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
+ relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
return relative_buckets
@@ -1184,7 +1123,7 @@ def _set_gradient_checkpointing(self, module, value=False):
"The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
WAVLM_START_DOCSTRING,
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM, WavLMBaseModelOutput->Wav2Vec2BaseModelOutput
class WavLMModel(WavLMPreTrainedModel):
def __init__(self, config: WavLMConfig):
super().__init__(config)
@@ -1275,7 +1214,7 @@ def _mask_hidden_states(
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=WavLMBaseModelOutput,
+ output_type=Wav2Vec2BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="audio",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -1288,7 +1227,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, WavLMBaseModelOutput]:
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1325,7 +1264,7 @@ def forward(
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
- return WavLMBaseModelOutput(
+ return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
@@ -1606,6 +1545,7 @@ def __init__(self, config):
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
self.init_weights()
@@ -1649,6 +1589,7 @@ def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1681,12 +1622,17 @@ def forward(
logits = self.classifier(hidden_states)
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return output
return TokenClassifierOutput(
- loss=None,
+ loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
diff --git a/src/transformers/models/xglm/__init__.py b/src/transformers/models/xglm/__init__.py
index d5934dea6666ac..2ab60e4cb4bbc9 100644
--- a/src/transformers/models/xglm/__init__.py
+++ b/src/transformers/models/xglm/__init__.py
@@ -19,6 +19,7 @@
# rely on isort to merge the imports
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"],
-}
+_import_structure = {"configuration_xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xglm"] = ["XGLMTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xglm_fast"] = ["XGLMTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xglm"] = [
"XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XGLMForCausalLM",
@@ -46,7 +60,12 @@
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_xglm"] = [
"FlaxXGLMForCausalLM",
"FlaxXGLMModel",
@@ -57,16 +76,36 @@
if TYPE_CHECKING:
from .configuration_xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xglm import XGLMTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xglm_fast import XGLMTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index f26c7fa818390f..fa2a8c6eb60667 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -142,7 +142,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
@@ -330,7 +330,8 @@ def forward(
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
)
if attention_mask is not None:
@@ -346,7 +347,8 @@ def forward(
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@@ -367,7 +369,8 @@ def forward(
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
@@ -574,7 +577,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
- ).to(self.device)
+ ).to(inputs_embeds.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -709,7 +712,7 @@ def forward(
hidden_states = inputs_embeds + positions
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -722,7 +725,8 @@ def forward(
if attn_mask is not None:
if attn_mask.size()[0] != len(self.layers):
raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -738,7 +742,8 @@ def forward(
if use_cache:
logger.warning(
- "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
+ "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache ="
+ " False`..."
)
use_cache = False
diff --git a/src/transformers/models/xlm/__init__.py b/src/transformers/models/xlm/__init__.py
index f0a42e244e7edb..de9be348b94c63 100644
--- a/src/transformers/models/xlm/__init__.py
+++ b/src/transformers/models/xlm/__init__.py
@@ -18,15 +18,20 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tf_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {
- "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig"],
+ "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMOnnxConfig"],
"tokenization_xlm": ["XLMTokenizer"],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlm"] = [
"XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMForMultipleChoice",
@@ -39,7 +44,12 @@
"XLMWithLMHeadModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_xlm"] = [
"TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLMForMultipleChoice",
@@ -54,10 +64,15 @@
if TYPE_CHECKING:
- from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
+ from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMOnnxConfig
from .tokenization_xlm import XLMTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMForMultipleChoice,
@@ -70,7 +85,12 @@
XLMWithLMHeadModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice,
diff --git a/src/transformers/models/xlm/configuration_xlm.py b/src/transformers/models/xlm/configuration_xlm.py
index d6f70c6671cc7c..e14ad2ec6cae34 100644
--- a/src/transformers/models/xlm/configuration_xlm.py
+++ b/src/transformers/models/xlm/configuration_xlm.py
@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" XLM configuration"""
+from collections import OrderedDict
+from typing import Mapping
from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging
@@ -228,3 +231,20 @@ def __init__(
self.n_words = kwargs["n_words"]
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
+class XLMOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ("token_type_ids", dynamic_axis),
+ ]
+ )
diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py
index 24d32f798f3d1a..fa3a54b6cc078a 100644
--- a/src/transformers/models/xlm/modeling_tf_xlm.py
+++ b/src/transformers/models/xlm/modeling_tf_xlm.py
@@ -92,8 +92,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
mask = padding_mask
else:
# assert lengths.max().item() <= slen
- alen = tf.range(slen)
- mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
+ alen = tf.range(slen, dtype=lengths.dtype)
+ mask = alen < tf.expand_dims(lengths, axis=1)
# attention mask is the same as mask, or triangular inferior attention (causal)
if causal:
diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py
index ed0817afbb8955..ebb3c503475c02 100755
--- a/src/transformers/models/xlm/modeling_xlm.py
+++ b/src/transformers/models/xlm/modeling_xlm.py
@@ -1039,7 +1039,7 @@ def forward(
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
... 0
- >>> ) # Batch size 1
+ ... ) # Batch size 1
>>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3])
diff --git a/src/transformers/models/xlm/tokenization_xlm.py b/src/transformers/models/xlm/tokenization_xlm.py
index 7519a514c9625b..bd7b58eb053b0e 100644
--- a/src/transformers/models/xlm/tokenization_xlm.py
+++ b/src/transformers/models/xlm/tokenization_xlm.py
@@ -22,8 +22,6 @@
import unicodedata
from typing import List, Optional, Tuple
-import sacremoses as sm
-
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
@@ -629,6 +627,16 @@ def __init__(
**kwargs,
)
+ try:
+ import sacremoses
+ except ImportError:
+ raise ImportError(
+ "You need to install sacremoses to use XLMTokenizer. "
+ "See https://pypi.org/project/sacremoses/ for installation."
+ )
+
+ self.sm = sacremoses
+
# cache of sm.MosesPunctNormalizer instance
self.cache_moses_punct_normalizer = dict()
# cache of sm.MosesTokenizer instance
@@ -659,7 +667,7 @@ def do_lower_case(self):
def moses_punct_norm(self, text, lang):
if lang not in self.cache_moses_punct_normalizer:
- punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
+ punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
self.cache_moses_punct_normalizer[lang] = punct_normalizer
else:
punct_normalizer = self.cache_moses_punct_normalizer[lang]
@@ -667,7 +675,7 @@ def moses_punct_norm(self, text, lang):
def moses_tokenize(self, text, lang):
if lang not in self.cache_moses_tokenizer:
- moses_tokenizer = sm.MosesTokenizer(lang=lang)
+ moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
self.cache_moses_tokenizer[lang] = moses_tokenizer
else:
moses_tokenizer = self.cache_moses_tokenizer[lang]
@@ -689,7 +697,8 @@ def ja_tokenize(self, text):
)
except (AttributeError, ImportError):
logger.error(
- "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
+ "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper"
+ " (https://github.com/chezou/Mykytea-python) with the following steps"
)
logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
logger.error("2. autoreconf -i")
@@ -793,7 +802,8 @@ def _tokenize(self, text, lang="en", bypass_tokenizer=False):
"""
if lang and self.lang2id and lang not in self.lang2id:
logger.error(
- "Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model."
+ "Supplied language code not found in lang2id mapping. Please check that your language is supported by"
+ " the loaded pretrained model."
)
if bypass_tokenizer:
text = text.split()
@@ -955,7 +965,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -970,3 +980,21 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
index += 1
return vocab_file, merge_file
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sm"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ try:
+ import sacremoses
+ except ImportError:
+ raise ImportError(
+ "You need to install sacremoses to use XLMTokenizer. "
+ "See https://pypi.org/project/sacremoses/ for installation."
+ )
+
+ self.sm = sacremoses
diff --git a/src/transformers/models/xlm_prophetnet/__init__.py b/src/transformers/models/xlm_prophetnet/__init__.py
index fe69b506076519..8fbec3d400ed59 100644
--- a/src/transformers/models/xlm_prophetnet/__init__.py
+++ b/src/transformers/models/xlm_prophetnet/__init__.py
@@ -17,20 +17,27 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_sentencepiece_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_torch_available
_import_structure = {
- "configuration_xlm_prophetnet": [
- "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP",
- "XLMProphetNetConfig",
- ],
+ "configuration_xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlm_prophetnet"] = ["XLMProphetNetTokenizer"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlm_prophetnet"] = [
"XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMProphetNetDecoder",
@@ -44,10 +51,20 @@
if TYPE_CHECKING:
from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlm_prophetnet import (
XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMProphetNetDecoder,
diff --git a/src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py
index 2c3d21bd283c26..3025ed29f64328 100644
--- a/src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py
+++ b/src/transformers/models/xlm_prophetnet/configuration_xlm_prophetnet.py
@@ -22,7 +22,9 @@
logger = logging.get_logger(__name__)
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
- "microsoft/xprophetnet-large-wiki100-cased": "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/config.json",
+ "microsoft/xprophetnet-large-wiki100-cased": (
+ "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
index dfb7b394915b5b..8961fbbfc37403 100644
--- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
+++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
@@ -98,7 +98,7 @@ class XLMProphetNetModel(ProphetNetModel):
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
@@ -124,7 +124,7 @@ class XLMProphetNetForConditionalGeneration(ProphetNetForConditionalGeneration):
>>> input_ids = tokenizer(
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
- >>> ).input_ids # Batch size 1
+ ... ).input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
diff --git a/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py
index 48f68238f126e4..af8308287939f8 100644
--- a/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py
+++ b/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py
@@ -30,7 +30,9 @@
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
- "microsoft/xprophetnet-large-wiki100-cased": "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/prophetnet.tokenizer",
+ "microsoft/xprophetnet-large-wiki100-cased": (
+ "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/prophetnet.tokenizer"
+ ),
}
}
@@ -159,8 +161,8 @@ def __init__(
import sentencepiece as spm
except ImportError:
logger.warning(
- "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece "
- "pip install sentencepiece"
+ "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece"
+ " pip install sentencepiece"
)
raise
@@ -198,8 +200,8 @@ def __setstate__(self, d):
import sentencepiece as spm
except ImportError:
logger.warning(
- "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece "
- "pip install sentencepiece"
+ "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece"
+ " pip install sentencepiece"
)
raise
diff --git a/src/transformers/models/xlm_roberta/__init__.py b/src/transformers/models/xlm_roberta/__init__.py
index a29a400c8b7d9c..60d26c1314847b 100644
--- a/src/transformers/models/xlm_roberta/__init__.py
+++ b/src/transformers/models/xlm_roberta/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
@@ -36,13 +37,28 @@
],
}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlm_roberta"] = ["XLMRobertaTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlm_roberta_fast"] = ["XLMRobertaTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlm_roberta"] = [
"XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMRobertaForCausalLM",
@@ -54,7 +70,12 @@
"XLMRobertaModel",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_xlm_roberta"] = [
"TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLMRobertaForMaskedLM",
@@ -65,7 +86,12 @@
"TFXLMRobertaModel",
]
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_xlm_roberta"] = [
"FlaxXLMRobertaForMaskedLM",
"FlaxXLMRobertaForMultipleChoice",
@@ -82,13 +108,28 @@
XLMRobertaOnnxConfig,
)
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlm_roberta import XLMRobertaTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlm_roberta import (
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMRobertaForCausalLM,
@@ -100,7 +141,12 @@
XLMRobertaModel,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_xlm_roberta import (
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMRobertaForMaskedLM,
@@ -111,7 +157,12 @@
TFXLMRobertaModel,
)
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_flax_xlm_roberta import (
FlaxXLMRobertaForMaskedLM,
FlaxXLMRobertaForMultipleChoice,
diff --git a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
index c1469bfca4cfce..194b38a8c181ec 100644
--- a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py
@@ -27,10 +27,18 @@
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/config.json",
"xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/config.json",
- "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json",
- "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json",
- "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json",
- "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json",
+ "xlm-roberta-large-finetuned-conll02-dutch": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json"
+ ),
+ "xlm-roberta-large-finetuned-conll02-spanish": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json"
+ ),
+ "xlm-roberta-large-finetuned-conll03-english": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json"
+ ),
+ "xlm-roberta-large-finetuned-conll03-german": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json"
+ ),
}
diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
index 072933a12ea6ac..40928d8dc30623 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
@@ -35,10 +35,18 @@
"vocab_file": {
"xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model",
"xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model",
+ "xlm-roberta-large-finetuned-conll02-dutch": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll02-spanish": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll03-english": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll03-german": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model"
+ ),
}
}
diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
index 119d2fa080f2de..f99e3c086a88c5 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
@@ -38,18 +38,34 @@
"vocab_file": {
"xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model",
"xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model",
- "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model",
+ "xlm-roberta-large-finetuned-conll02-dutch": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll02-spanish": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll03-english": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model"
+ ),
+ "xlm-roberta-large-finetuned-conll03-german": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model"
+ ),
},
"tokenizer_file": {
"xlm-roberta-base": "https://huggingface.co/xlm-roberta-base/resolve/main/tokenizer.json",
"xlm-roberta-large": "https://huggingface.co/xlm-roberta-large/resolve/main/tokenizer.json",
- "xlm-roberta-large-finetuned-conll02-dutch": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/tokenizer.json",
- "xlm-roberta-large-finetuned-conll02-spanish": "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/tokenizer.json",
- "xlm-roberta-large-finetuned-conll03-english": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/tokenizer.json",
- "xlm-roberta-large-finetuned-conll03-german": "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/tokenizer.json",
+ "xlm-roberta-large-finetuned-conll02-dutch": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/tokenizer.json"
+ ),
+ "xlm-roberta-large-finetuned-conll02-spanish": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/tokenizer.json"
+ ),
+ "xlm-roberta-large-finetuned-conll03-english": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/tokenizer.json"
+ ),
+ "xlm-roberta-large-finetuned-conll03-german": (
+ "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/tokenizer.json"
+ ),
},
}
diff --git a/src/transformers/models/xlm_roberta_xl/__init__.py b/src/transformers/models/xlm_roberta_xl/__init__.py
index 765a235f29e364..3140e3bd226718 100644
--- a/src/transformers/models/xlm_roberta_xl/__init__.py
+++ b/src/transformers/models/xlm_roberta_xl/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -29,7 +29,12 @@
],
}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlm_roberta_xl"] = [
"XLM_ROBERTA_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMRobertaXLForCausalLM",
@@ -49,7 +54,12 @@
XLMRobertaXLOnnxConfig,
)
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlm_roberta_xl import (
XLM_ROBERTA_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMRobertaXLForCausalLM,
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index e6c3ac3ec8c79a..70dd4221573be8 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -176,7 +176,7 @@ def __init__(self, config, position_embedding_type=None):
self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@@ -415,7 +415,8 @@ def forward(
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
@@ -788,7 +789,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
diff --git a/src/transformers/models/xlnet/__init__.py b/src/transformers/models/xlnet/__init__.py
index 599448a271df50..d01edf267cc1d7 100644
--- a/src/transformers/models/xlnet/__init__.py
+++ b/src/transformers/models/xlnet/__init__.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING
from ...utils import (
+ OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tf_available,
@@ -27,17 +28,30 @@
)
-_import_structure = {
- "configuration_xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
-}
+_import_structure = {"configuration_xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"]}
-if is_sentencepiece_available():
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlnet"] = ["XLNetTokenizer"]
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_xlnet_fast"] = ["XLNetTokenizerFast"]
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_xlnet"] = [
"XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLNetForMultipleChoice",
@@ -51,7 +65,12 @@
"load_tf_weights_in_xlnet",
]
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_xlnet"] = [
"TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLNetForMultipleChoice",
@@ -68,13 +87,28 @@
if TYPE_CHECKING:
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
- if is_sentencepiece_available():
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlnet import XLNetTokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_xlnet_fast import XLNetTokenizerFast
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_xlnet import (
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLNetForMultipleChoice,
@@ -88,7 +122,12 @@
load_tf_weights_in_xlnet,
)
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_xlnet import (
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLNetForMultipleChoice,
diff --git a/src/transformers/models/xlnet/configuration_xlnet.py b/src/transformers/models/xlnet/configuration_xlnet.py
index bc6f0f68356f41..5448f9248ced92 100644
--- a/src/transformers/models/xlnet/configuration_xlnet.py
+++ b/src/transformers/models/xlnet/configuration_xlnet.py
@@ -219,7 +219,8 @@ def __init__(
if "use_cache" in kwargs:
warnings.warn(
- "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems_eval` instead.",
+ "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems_eval`"
+ " instead.",
FutureWarning,
)
use_mems_eval = kwargs["use_cache"]
diff --git a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py
index f6fc73ca0e585d..804b52b0dc8792 100755
--- a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py
@@ -88,8 +88,10 @@ def convert_xlnet_checkpoint_to_pytorch(
default=None,
type=str,
required=True,
- help="The config json file corresponding to the pre-trained XLNet model. \n"
- "This specifies the model architecture.",
+ help=(
+ "The config json file corresponding to the pre-trained XLNet model. \n"
+ "This specifies the model architecture."
+ ),
)
parser.add_argument(
"--pytorch_dump_folder_path",
diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py
index f5a1cba3c837f1..df4111d2631727 100644
--- a/src/transformers/models/xlnet/modeling_tf_xlnet.py
+++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py
@@ -1281,17 +1281,17 @@ def call(
>>> # We show how to setup inputs to predict a next token using a bi-directional context.
>>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is very ", add_special_tokens=True))[
... None, :
- >>> ] # We will predict the masked token
+ ... ] # We will predict the masked token
>>> perm_mask = np.zeros((1, input_ids.shape[1], input_ids.shape[1]))
>>> perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
>>> target_mapping = np.zeros(
... (1, 1, input_ids.shape[1])
- >>> ) # Shape [1, 1, seq_length] => let's predict one token
+ ... ) # Shape [1, 1, seq_length] => let's predict one token
>>> target_mapping[
... 0, 0, -1
- >>> ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
+ ... ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
>>> outputs = model(
... input_ids,
@@ -1301,7 +1301,7 @@ def call(
>>> next_token_logits = outputs[
... 0
- >>> ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
+ ... ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
```"""
transformer_outputs = self.transformer(
input_ids=input_ids,
diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py
index 079c636628f2fd..4a299a5a657f42 100755
--- a/src/transformers/models/xlnet/modeling_xlnet.py
+++ b/src/transformers/models/xlnet/modeling_xlnet.py
@@ -1056,7 +1056,6 @@ def relative_positional_encoding(self, qlen, klen, bsz=None):
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
- pos_emb = pos_emb.to(self.device)
return pos_emb
@add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1092,7 +1091,8 @@ def forward(
if "use_cache" in kwargs:
warnings.warn(
- "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems` instead.",
+ "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems`"
+ " instead.",
FutureWarning,
)
use_mems = kwargs["use_cache"]
@@ -1205,6 +1205,7 @@ def forward(
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
+ pos_emb = pos_emb.to(output_h.device)
pos_emb = self.dropout(pos_emb)
# Prepare head mask if needed
@@ -1400,47 +1401,47 @@ def forward(
>>> # We show how to setup inputs to predict a next token using a bi-directional context.
>>> input_ids = torch.tensor(
... tokenizer.encode("Hello, my dog is very ", add_special_tokens=False)
- >>> ).unsqueeze(
+ ... ).unsqueeze(
... 0
- >>> ) # We will predict the masked token
+ ... ) # We will predict the masked token
>>> perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
>>> perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
>>> target_mapping = torch.zeros(
... (1, 1, input_ids.shape[1]), dtype=torch.float
- >>> ) # Shape [1, 1, seq_length] => let's predict one token
+ ... ) # Shape [1, 1, seq_length] => let's predict one token
>>> target_mapping[
... 0, 0, -1
- >>> ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
+ ... ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
>>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
>>> next_token_logits = outputs[
... 0
- >>> ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
+ ... ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
>>> # The same way can the XLNetLMHeadModel be used to be trained by standard auto-regressive language modeling.
>>> input_ids = torch.tensor(
... tokenizer.encode("Hello, my dog is very ", add_special_tokens=False)
- >>> ).unsqueeze(
+ ... ).unsqueeze(
... 0
- >>> ) # We will predict the masked token
+ ... ) # We will predict the masked token
>>> labels = torch.tensor(tokenizer.encode("cute", add_special_tokens=False)).unsqueeze(0)
>>> assert labels.shape[0] == 1, "only one word will be predicted"
>>> perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
>>> perm_mask[
... :, :, -1
- >>> ] = 1.0 # Previous tokens don't see last token as is done in standard auto-regressive lm training
+ ... ] = 1.0 # Previous tokens don't see last token as is done in standard auto-regressive lm training
>>> target_mapping = torch.zeros(
... (1, 1, input_ids.shape[1]), dtype=torch.float
- >>> ) # Shape [1, 1, seq_length] => let's predict one token
+ ... ) # Shape [1, 1, seq_length] => let's predict one token
>>> target_mapping[
... 0, 0, -1
- >>> ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
+ ... ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
>>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=labels)
>>> loss = outputs.loss
>>> next_token_logits = (
... outputs.logits
- >>> ) # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
+ ... ) # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1980,7 +1981,7 @@ def forward(
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
... 0
- >>> ) # Batch size 1
+ ... ) # Batch size 1
>>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3])
>>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
diff --git a/src/transformers/models/yolos/__init__.py b/src/transformers/models/yolos/__init__.py
new file mode 100644
index 00000000000000..91cc0e703213c0
--- /dev/null
+++ b/src/transformers/models/yolos/__init__.py
@@ -0,0 +1,75 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {"configuration_yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"]}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_yolos"] = ["YolosFeatureExtractor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_yolos"] = [
+ "YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "YolosForObjectDetection",
+ "YolosModel",
+ "YolosPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_yolos import YolosFeatureExtractor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_yolos import (
+ YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST,
+ YolosForObjectDetection,
+ YolosModel,
+ YolosPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/yolos/configuration_yolos.py b/src/transformers/models/yolos/configuration_yolos.py
new file mode 100644
index 00000000000000..cd3414a7f26eed
--- /dev/null
+++ b/src/transformers/models/yolos/configuration_yolos.py
@@ -0,0 +1,153 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" YOLOS model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "hustvl/yolos-small": "https://huggingface.co/hustvl/yolos-small/resolve/main/config.json",
+ # See all YOLOS models at https://huggingface.co/models?filter=yolos
+}
+
+
+class YolosConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`YolosModel`]. It is used to instantiate a YOLOS
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the YOLOS
+ [hustvl/yolos-base](https://huggingface.co/hustvl/yolos-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`List[int]`, *optional*, defaults to `[512, 864]`):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to `16`):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to `3`):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ num_detection_tokens (`int`, *optional*, defaults to `100`):
+ The number of detection tokens.
+ use_mid_position_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to use the mid-layer position encodings.
+ auxiliary_loss (`bool`, *optional*, defaults to `False`):
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
+ class_cost (`float`, *optional*, defaults to 1):
+ Relative weight of the classification error in the Hungarian matching cost.
+ bbox_cost (`float`, *optional*, defaults to 5):
+ Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
+ giou_cost (`float`, *optional*, defaults to 2):
+ Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
+ bbox_loss_coefficient (`float`, *optional*, defaults to 5):
+ Relative weight of the L1 bounding box loss in the object detection loss.
+ giou_loss_coefficient (`float`, *optional*, defaults to 2):
+ Relative weight of the generalized IoU loss in the object detection loss.
+ eos_coefficient (`float`, *optional*, defaults to 0.1):
+ Relative classification weight of the 'no-object' class in the object detection loss.
+
+ Example:
+
+ ```python
+ >>> from transformers import YolosModel, YolosConfig
+
+ >>> # Initializing a YOLOS hustvl/yolos-base style configuration
+ >>> configuration = YolosConfig()
+
+ >>> # Initializing a model from the hustvl/yolos-base style configuration
+ >>> model = YolosModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "yolos"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=[512, 864],
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ num_detection_tokens=100,
+ use_mid_position_embeddings=True,
+ auxiliary_loss=False,
+ class_cost=1,
+ bbox_cost=5,
+ giou_cost=2,
+ bbox_loss_coefficient=5,
+ giou_loss_coefficient=2,
+ eos_coefficient=0.1,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.num_detection_tokens = num_detection_tokens
+ self.use_mid_position_embeddings = use_mid_position_embeddings
+ self.auxiliary_loss = auxiliary_loss
+ # Hungarian matcher
+ self.class_cost = class_cost
+ self.bbox_cost = bbox_cost
+ self.giou_cost = giou_cost
+ # Loss coefficients
+ self.bbox_loss_coefficient = bbox_loss_coefficient
+ self.giou_loss_coefficient = giou_loss_coefficient
+ self.eos_coefficient = eos_coefficient
diff --git a/src/transformers/models/yolos/convert_yolos_to_pytorch.py b/src/transformers/models/yolos/convert_yolos_to_pytorch.py
new file mode 100644
index 00000000000000..7f4161a632d89f
--- /dev/null
+++ b/src/transformers/models/yolos/convert_yolos_to_pytorch.py
@@ -0,0 +1,266 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert YOLOS checkpoints from the original repository. URL: https://github.com/hustvl/YOLOS"""
+
+
+import argparse
+import json
+from pathlib import Path
+
+import torch
+from PIL import Image
+
+import requests
+from huggingface_hub import hf_hub_download
+from transformers import YolosConfig, YolosFeatureExtractor, YolosForObjectDetection
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def get_yolos_config(yolos_name):
+ config = YolosConfig()
+
+ # size of the architecture
+ if "yolos_ti" in yolos_name:
+ config.hidden_size = 192
+ config.intermediate_size = 768
+ config.num_hidden_layers = 12
+ config.num_attention_heads = 3
+ config.image_size = [800, 1333]
+ config.use_mid_position_embeddings = False
+ elif yolos_name == "yolos_s_dWr":
+ config.hidden_size = 330
+ config.num_hidden_layers = 14
+ config.num_attention_heads = 6
+ config.intermediate_size = 1320
+ elif "yolos_s" in yolos_name:
+ config.hidden_size = 384
+ config.intermediate_size = 1536
+ config.num_hidden_layers = 12
+ config.num_attention_heads = 6
+ elif "yolos_b" in yolos_name:
+ config.image_size = [800, 1344]
+
+ config.num_labels = 91
+ repo_id = "datasets/huggingface/label-files"
+ filename = "coco-detection-id2label.json"
+ id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+ config.id2label = id2label
+ config.label2id = {v: k for k, v in id2label.items()}
+
+ return config
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
+def read_in_q_k_v(state_dict, config, base_model=False):
+ for i in range(config.num_hidden_layers):
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
+ in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
+ in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
+ # next, add query, keys and values (in that order) to the state dict
+ state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :]
+ state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
+ state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+ config.hidden_size : config.hidden_size * 2, :
+ ]
+ state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
+ config.hidden_size : config.hidden_size * 2
+ ]
+ state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :]
+ state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
+
+
+def rename_key(name):
+ if "backbone" in name:
+ name = name.replace("backbone", "vit")
+ if "cls_token" in name:
+ name = name.replace("cls_token", "embeddings.cls_token")
+ if "det_token" in name:
+ name = name.replace("det_token", "embeddings.detection_tokens")
+ if "mid_pos_embed" in name:
+ name = name.replace("mid_pos_embed", "encoder.mid_position_embeddings")
+ if "pos_embed" in name:
+ name = name.replace("pos_embed", "embeddings.position_embeddings")
+ if "patch_embed.proj" in name:
+ name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
+ if "blocks" in name:
+ name = name.replace("blocks", "encoder.layer")
+ if "attn.proj" in name:
+ name = name.replace("attn.proj", "attention.output.dense")
+ if "attn" in name:
+ name = name.replace("attn", "attention.self")
+ if "norm1" in name:
+ name = name.replace("norm1", "layernorm_before")
+ if "norm2" in name:
+ name = name.replace("norm2", "layernorm_after")
+ if "mlp.fc1" in name:
+ name = name.replace("mlp.fc1", "intermediate.dense")
+ if "mlp.fc2" in name:
+ name = name.replace("mlp.fc2", "output.dense")
+ if "class_embed" in name:
+ name = name.replace("class_embed", "class_labels_classifier")
+ if "bbox_embed" in name:
+ name = name.replace("bbox_embed", "bbox_predictor")
+ if "vit.norm" in name:
+ name = name.replace("vit.norm", "vit.layernorm")
+
+ return name
+
+
+def convert_state_dict(orig_state_dict, model):
+ for key in orig_state_dict.copy().keys():
+ val = orig_state_dict.pop(key)
+
+ if "qkv" in key:
+ key_split = key.split(".")
+ layer_num = int(key_split[2])
+ dim = model.vit.encoder.layer[layer_num].attention.attention.all_head_size
+ if "weight" in key:
+ orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.query.weight"] = val[:dim, :]
+ orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.key.weight"] = val[
+ dim : dim * 2, :
+ ]
+ orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
+ else:
+ orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.query.bias"] = val[:dim]
+ orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
+ orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.value.bias"] = val[-dim:]
+ else:
+ orig_state_dict[rename_key(key)] = val
+
+ return orig_state_dict
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@torch.no_grad()
+def convert_yolos_checkpoint(yolos_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False):
+ """
+ Copy/paste/tweak model's weights to our YOLOS structure.
+ """
+ config = get_yolos_config(yolos_name)
+
+ # load original state_dict
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+
+ # load š¤ model
+ model = YolosForObjectDetection(config)
+ model.eval()
+ new_state_dict = convert_state_dict(state_dict, model)
+ model.load_state_dict(new_state_dict)
+
+ # Check outputs on an image, prepared by YolosFeatureExtractor
+ size = 800 if yolos_name != "yolos_ti" else 512
+ feature_extractor = YolosFeatureExtractor(format="coco_detection", size=size)
+ encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
+ outputs = model(**encoding)
+ logits, pred_boxes = outputs.logits, outputs.pred_boxes
+
+ expected_slice_logits, expected_slice_boxes = None, None
+ if yolos_name == "yolos_ti":
+ expected_slice_logits = torch.tensor(
+ [[-39.5022, -11.9820, -17.6888], [-29.9574, -9.9769, -17.7691], [-42.3281, -20.7200, -30.6294]]
+ )
+ expected_slice_boxes = torch.tensor(
+ [[0.4021, 0.0836, 0.7979], [0.0184, 0.2609, 0.0364], [0.1781, 0.2004, 0.2095]]
+ )
+ elif yolos_name == "yolos_s_200_pre":
+ expected_slice_logits = torch.tensor(
+ [[-24.0248, -10.3024, -14.8290], [-42.0392, -16.8200, -27.4334], [-27.2743, -11.8154, -18.7148]]
+ )
+ expected_slice_boxes = torch.tensor(
+ [[0.2559, 0.5455, 0.4706], [0.2989, 0.7279, 0.1875], [0.7732, 0.4017, 0.4462]]
+ )
+ elif yolos_name == "yolos_s_300_pre":
+ expected_slice_logits = torch.tensor(
+ [[-36.2220, -14.4385, -23.5457], [-35.6970, -14.7583, -21.3935], [-31.5939, -13.6042, -16.8049]]
+ )
+ expected_slice_boxes = torch.tensor(
+ [[0.7614, 0.2316, 0.4728], [0.7168, 0.4495, 0.3855], [0.4996, 0.1466, 0.9996]]
+ )
+ elif yolos_name == "yolos_s_dWr":
+ expected_slice_logits = torch.tensor(
+ [[-42.8668, -24.1049, -41.1690], [-34.7456, -14.1274, -24.9194], [-33.7898, -12.1946, -25.6495]]
+ )
+ expected_slice_boxes = torch.tensor(
+ [[0.5587, 0.2773, 0.0605], [0.5004, 0.3014, 0.9994], [0.4999, 0.1548, 0.9994]]
+ )
+ elif yolos_name == "yolos_base":
+ expected_slice_logits = torch.tensor(
+ [[-40.6064, -24.3084, -32.6447], [-55.1990, -30.7719, -35.5877], [-51.4311, -33.3507, -35.6462]]
+ )
+ expected_slice_boxes = torch.tensor(
+ [[0.5555, 0.2794, 0.0655], [0.9049, 0.2664, 0.1894], [0.9183, 0.1984, 0.1635]]
+ )
+ else:
+ raise ValueError(f"Unknown yolos_name: {yolos_name}")
+
+ assert torch.allclose(logits[0, :3, :3], expected_slice_logits, atol=1e-4)
+ assert torch.allclose(pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)
+
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+ print(f"Saving model {yolos_name} to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+ print(f"Saving feature extractor to {pytorch_dump_folder_path}")
+ feature_extractor.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_hub:
+ model_mapping = {
+ "yolos_ti": "yolos-tiny",
+ "yolos_s_200_pre": "yolos-small",
+ "yolos_s_300_pre": "yolos-small-300",
+ "yolos_s_dWr": "yolos-small-dwr",
+ "yolos_base": "yolos-base",
+ }
+
+ print("Pushing to the hub...")
+ model_name = model_mapping[yolos_name]
+ feature_extractor.push_to_hub(model_name, organization="hustvl")
+ model.push_to_hub(model_name, organization="hustvl")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--yolos_name",
+ default="yolos_s_200_pre",
+ type=str,
+ help=(
+ "Name of the YOLOS model you'd like to convert. Should be one of 'yolos_ti', 'yolos_s_200_pre',"
+ " 'yolos_s_300_pre', 'yolos_s_dWr', 'yolos_base'."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, help="Path to the original state dict (.pth file)."
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ )
+ parser.add_argument(
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the š¤ hub."
+ )
+
+ args = parser.parse_args()
+ convert_yolos_checkpoint(args.yolos_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/src/transformers/models/yolos/feature_extraction_yolos.py b/src/transformers/models/yolos/feature_extraction_yolos.py
new file mode 100644
index 00000000000000..e199d1ae7bf463
--- /dev/null
+++ b/src/transformers/models/yolos/feature_extraction_yolos.py
@@ -0,0 +1,917 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for YOLOS."""
+
+import io
+import pathlib
+from collections import defaultdict
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
+from ...utils import TensorType, is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+logger = logging.get_logger(__name__)
+
+
+ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]
+
+
+# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
+def center_to_corners_format(x):
+ """
+ Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
+ (x_0, y_0, x_1, y_1).
+ """
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+# Copied from transformers.models.detr.feature_extraction_detr.corners_to_center_format
+def corners_to_center_format(x):
+ """
+ Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
+ y_1) to center format (center_x, center_y, width, height).
+ """
+ x_transposed = x.T
+ x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+ return np.stack(b, axis=-1)
+
+
+# Copied from transformers.models.detr.feature_extraction_detr.masks_to_boxes
+def masks_to_boxes(masks):
+ """
+ Compute the bounding boxes around the provided panoptic segmentation masks.
+
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
+
+ Returns a [N, 4] tensor, with the boxes in corner (xyxy) format.
+ """
+ if masks.size == 0:
+ return np.zeros((0, 4))
+
+ h, w = masks.shape[-2:]
+
+ y = np.arange(0, h, dtype=np.float32)
+ x = np.arange(0, w, dtype=np.float32)
+ # see https://github.com/pytorch/pytorch/issues/50276
+ y, x = np.meshgrid(y, x, indexing="ij")
+
+ x_mask = masks * np.expand_dims(x, axis=0)
+ x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1)
+ x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool)))
+ x_min = x.filled(fill_value=1e8)
+ x_min = x_min.reshape(x_min.shape[0], -1).min(-1)
+
+ y_mask = masks * np.expand_dims(y, axis=0)
+ y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1)
+ y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool)))
+ y_min = y.filled(fill_value=1e8)
+ y_min = y_min.reshape(y_min.shape[0], -1).min(-1)
+
+ return np.stack([x_min, y_min, x_max, y_max], 1)
+
+
+# Copied from transformers.models.detr.feature_extraction_detr.rgb_to_id
+def rgb_to_id(color):
+ if isinstance(color, np.ndarray) and len(color.shape) == 3:
+ if color.dtype == np.uint8:
+ color = color.astype(np.int32)
+ return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
+ return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
+
+
+# Copied from transformers.models.detr.feature_extraction_detr.id_to_rgb
+def id_to_rgb(id_map):
+ if isinstance(id_map, np.ndarray):
+ id_map_copy = id_map.copy()
+ rgb_shape = tuple(list(id_map.shape) + [3])
+ rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
+ for i in range(3):
+ rgb_map[..., i] = id_map_copy % 256
+ id_map_copy //= 256
+ return rgb_map
+ color = []
+ for _ in range(3):
+ color.append(id_map % 256)
+ id_map //= 256
+ return color
+
+
+class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a YOLOS feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+
+ Args:
+ format (`str`, *optional*, defaults to `"coco_detection"`):
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int`, *optional*, defaults to 800):
+ Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
+ sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
+ the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
+ height / width, size)`.
+ max_size (`int`, *optional*, defaults to `1333`):
+ The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
+ set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
+ The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
+ image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
+ ImageNet std.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.__init__
+ def __init__(
+ self,
+ format="coco_detection",
+ do_resize=True,
+ size=800,
+ max_size=1333,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.format = self._is_valid_format(format)
+ self.do_resize = do_resize
+ self.size = size
+ self.max_size = max_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406] # ImageNet mean
+ self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225] # ImageNet std
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor._is_valid_format
+ def _is_valid_format(self, format):
+ if format not in ["coco_detection", "coco_panoptic"]:
+ raise ValueError(f"Format {format} not supported")
+ return format
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.prepare
+ def prepare(self, image, target, return_segmentation_masks=False, masks_path=None):
+ if self.format == "coco_detection":
+ image, target = self.prepare_coco_detection(image, target, return_segmentation_masks)
+ return image, target
+ elif self.format == "coco_panoptic":
+ image, target = self.prepare_coco_panoptic(image, target, masks_path)
+ return image, target
+ else:
+ raise ValueError(f"Format {self.format} not supported")
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.convert_coco_poly_to_mask
+ def convert_coco_poly_to_mask(self, segmentations, height, width):
+
+ try:
+ from pycocotools import mask as coco_mask
+ except ImportError:
+ raise ImportError("Pycocotools is not installed in your environment.")
+
+ masks = []
+ for polygons in segmentations:
+ rles = coco_mask.frPyObjects(polygons, height, width)
+ mask = coco_mask.decode(rles)
+ if len(mask.shape) < 3:
+ mask = mask[..., None]
+ mask = np.asarray(mask, dtype=np.uint8)
+ mask = np.any(mask, axis=2)
+ masks.append(mask)
+ if masks:
+ masks = np.stack(masks, axis=0)
+ else:
+ masks = np.zeros((0, height, width), dtype=np.uint8)
+
+ return masks
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.prepare_coco_detection
+ def prepare_coco_detection(self, image, target, return_segmentation_masks=False):
+ """
+ Convert the target in COCO format into the format expected by DETR.
+ """
+ w, h = image.size
+
+ image_id = target["image_id"]
+ image_id = np.asarray([image_id], dtype=np.int64)
+
+ # get all COCO annotations for the given image
+ anno = target["annotations"]
+
+ anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]
+
+ boxes = [obj["bbox"] for obj in anno]
+ # guard against no boxes via resizing
+ boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=w)
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=h)
+
+ classes = [obj["category_id"] for obj in anno]
+ classes = np.asarray(classes, dtype=np.int64)
+
+ if return_segmentation_masks:
+ segmentations = [obj["segmentation"] for obj in anno]
+ masks = self.convert_coco_poly_to_mask(segmentations, h, w)
+
+ keypoints = None
+ if anno and "keypoints" in anno[0]:
+ keypoints = [obj["keypoints"] for obj in anno]
+ keypoints = np.asarray(keypoints, dtype=np.float32)
+ num_keypoints = keypoints.shape[0]
+ if num_keypoints:
+ keypoints = keypoints.reshape((-1, 3))
+
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+ boxes = boxes[keep]
+ classes = classes[keep]
+ if return_segmentation_masks:
+ masks = masks[keep]
+ if keypoints is not None:
+ keypoints = keypoints[keep]
+
+ target = {}
+ target["boxes"] = boxes
+ target["class_labels"] = classes
+ if return_segmentation_masks:
+ target["masks"] = masks
+ target["image_id"] = image_id
+ if keypoints is not None:
+ target["keypoints"] = keypoints
+
+ # for conversion to coco api
+ area = np.asarray([obj["area"] for obj in anno], dtype=np.float32)
+ iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno], dtype=np.int64)
+ target["area"] = area[keep]
+ target["iscrowd"] = iscrowd[keep]
+
+ target["orig_size"] = np.asarray([int(h), int(w)], dtype=np.int64)
+ target["size"] = np.asarray([int(h), int(w)], dtype=np.int64)
+
+ return image, target
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.prepare_coco_panoptic
+ def prepare_coco_panoptic(self, image, target, masks_path, return_masks=True):
+ w, h = image.size
+ ann_info = target.copy()
+ ann_path = pathlib.Path(masks_path) / ann_info["file_name"]
+
+ if "segments_info" in ann_info:
+ masks = np.asarray(Image.open(ann_path), dtype=np.uint32)
+ masks = rgb_to_id(masks)
+
+ ids = np.array([ann["id"] for ann in ann_info["segments_info"]])
+ masks = masks == ids[:, None, None]
+ masks = np.asarray(masks, dtype=np.uint8)
+
+ labels = np.asarray([ann["category_id"] for ann in ann_info["segments_info"]], dtype=np.int64)
+
+ target = {}
+ target["image_id"] = np.asarray(
+ [ann_info["image_id"] if "image_id" in ann_info else ann_info["id"]], dtype=np.int64
+ )
+ if return_masks:
+ target["masks"] = masks
+ target["class_labels"] = labels
+
+ target["boxes"] = masks_to_boxes(masks)
+
+ target["size"] = np.asarray([int(h), int(w)], dtype=np.int64)
+ target["orig_size"] = np.asarray([int(h), int(w)], dtype=np.int64)
+ if "segments_info" in ann_info:
+ target["iscrowd"] = np.asarray([ann["iscrowd"] for ann in ann_info["segments_info"]], dtype=np.int64)
+ target["area"] = np.asarray([ann["area"] for ann in ann_info["segments_info"]], dtype=np.float32)
+
+ return image, target
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor._resize
+ def _resize(self, image, size, target=None, max_size=None):
+ """
+ Resize the image to the given size. Size can be min_size (scalar) or (w, h) tuple. If size is an int, smaller
+ edge of the image will be matched to this number.
+
+ If given, also resize the target accordingly.
+ """
+ if not isinstance(image, Image.Image):
+ image = self.to_pil_image(image)
+
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
+ w, h = image_size
+ if max_size is not None:
+ min_original_size = float(min((w, h)))
+ max_original_size = float(max((w, h)))
+ if max_original_size / min_original_size * size > max_size:
+ size = int(round(max_size * min_original_size / max_original_size))
+
+ if (w <= h and w == size) or (h <= w and h == size):
+ return (h, w)
+
+ if w < h:
+ ow = size
+ oh = int(size * h / w)
+ else:
+ oh = size
+ ow = int(size * w / h)
+
+ return (oh, ow)
+
+ def get_size(image_size, size, max_size=None):
+ if isinstance(size, (list, tuple)):
+ return size
+ else:
+ # size returned must be (w, h) since we use PIL to resize images
+ # so we revert the tuple
+ return get_size_with_aspect_ratio(image_size, size, max_size)[::-1]
+
+ size = get_size(image.size, size, max_size)
+ rescaled_image = self.resize(image, size=size)
+
+ if target is None:
+ return rescaled_image, None
+
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+ ratio_width, ratio_height = ratios
+
+ target = target.copy()
+ if "boxes" in target:
+ boxes = target["boxes"]
+ scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
+ target["boxes"] = scaled_boxes
+
+ if "area" in target:
+ area = target["area"]
+ scaled_area = area * (ratio_width * ratio_height)
+ target["area"] = scaled_area
+
+ w, h = size
+ target["size"] = np.asarray([h, w], dtype=np.int64)
+
+ if "masks" in target:
+ # use PyTorch as current workaround
+ # TODO replace by self.resize
+ masks = torch.from_numpy(target["masks"][:, None]).float()
+ interpolated_masks = nn.functional.interpolate(masks, size=(h, w), mode="nearest")[:, 0] > 0.5
+ target["masks"] = interpolated_masks.numpy()
+
+ return rescaled_image, target
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor._normalize
+ def _normalize(self, image, mean, std, target=None):
+ """
+ Normalize the image with a certain mean and std.
+
+ If given, also normalize the target bounding boxes based on the size of the image.
+ """
+
+ image = self.normalize(image, mean=mean, std=std)
+ if target is None:
+ return image, None
+
+ target = target.copy()
+ h, w = image.shape[-2:]
+
+ if "boxes" in target:
+ boxes = target["boxes"]
+ boxes = corners_to_center_format(boxes)
+ boxes = boxes / np.asarray([w, h, w, h], dtype=np.float32)
+ target["boxes"] = boxes
+
+ return image, target
+
+ def __call__(
+ self,
+ images: ImageInput,
+ annotations: Union[List[Dict], List[List[Dict]]] = None,
+ return_segmentation_masks: Optional[bool] = False,
+ masks_path: Optional[pathlib.Path] = None,
+ padding: Optional[bool] = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s) and optional annotations. Images are by default
+ padded up to the largest image in a batch.
+
+
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
+ PIL images.
+
+
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+
+ annotations (`Dict`, `List[Dict]`, *optional*):
+ The corresponding annotations in COCO format.
+
+ In case [`DetrFeatureExtractor`] was initialized with `format = "coco_detection"`, the annotations for
+ each image should have the following format: {'image_id': int, 'annotations': [annotation]}, with the
+ annotations being a list of COCO object annotations.
+
+ In case [`DetrFeatureExtractor`] was initialized with `format = "coco_panoptic"`, the annotations for
+ each image should have the following format: {'image_id': int, 'file_name': str, 'segments_info':
+ [segment_info]} with segments_info being a list of COCO panoptic annotations.
+
+ return_segmentation_masks (`Dict`, `List[Dict]`, *optional*, defaults to `False`):
+ Whether to also include instance segmentation masks as part of the labels in case `format =
+ "coco_detection"`.
+
+ masks_path (`pathlib.Path`, *optional*):
+ Path to the directory containing the PNG files that store the class-agnostic image segmentations. Only
+ relevant in case [`DetrFeatureExtractor`] was initialized with `format = "coco_panoptic"`.
+
+ padding (`bool`, *optional*, defaults to `True`):
+ Whether or not to pad images up to the largest image in a batch.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
+ objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model.
+ - **labels** -- Optional labels to be fed to a model (when `annotations` are provided)
+ """
+ # Input type checking for clearer error
+
+ valid_images = False
+ valid_annotations = False
+ valid_masks_path = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+ "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ # Check that annotations has a valid type
+ if annotations is not None:
+ if not is_batched:
+ if self.format == "coco_detection":
+ if isinstance(annotations, dict) and "image_id" in annotations and "annotations" in annotations:
+ if isinstance(annotations["annotations"], (list, tuple)):
+ # an image can have no annotations
+ if len(annotations["annotations"]) == 0 or isinstance(annotations["annotations"][0], dict):
+ valid_annotations = True
+ elif self.format == "coco_panoptic":
+ if isinstance(annotations, dict) and "image_id" in annotations and "segments_info" in annotations:
+ if isinstance(annotations["segments_info"], (list, tuple)):
+ # an image can have no segments (?)
+ if len(annotations["segments_info"]) == 0 or isinstance(
+ annotations["segments_info"][0], dict
+ ):
+ valid_annotations = True
+ else:
+ if isinstance(annotations, (list, tuple)):
+ if len(images) != len(annotations):
+ raise ValueError("There must be as many annotations as there are images")
+ if isinstance(annotations[0], Dict):
+ if self.format == "coco_detection":
+ if isinstance(annotations[0]["annotations"], (list, tuple)):
+ valid_annotations = True
+ elif self.format == "coco_panoptic":
+ if isinstance(annotations[0]["segments_info"], (list, tuple)):
+ valid_annotations = True
+
+ if not valid_annotations:
+ raise ValueError(
+ """
+ Annotations must of type `Dict` (single image) or `List[Dict]` (batch of images). In case of object
+ detection, each dictionary should contain the keys 'image_id' and 'annotations', with the latter
+ being a list of annotations in COCO format. In case of panoptic segmentation, each dictionary
+ should contain the keys 'file_name', 'image_id' and 'segments_info', with the latter being a list
+ of annotations in COCO format.
+ """
+ )
+
+ # Check that masks_path has a valid type
+ if masks_path is not None:
+ if self.format == "coco_panoptic":
+ if isinstance(masks_path, pathlib.Path):
+ valid_masks_path = True
+ if not valid_masks_path:
+ raise ValueError(
+ "The path to the directory containing the mask PNG files should be provided as a"
+ " `pathlib.Path` object."
+ )
+
+ if not is_batched:
+ images = [images]
+ if annotations is not None:
+ annotations = [annotations]
+
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
+ if annotations is not None:
+ for idx, (image, target) in enumerate(zip(images, annotations)):
+ if not isinstance(image, Image.Image):
+ image = self.to_pil_image(image)
+ image, target = self.prepare(image, target, return_segmentation_masks, masks_path)
+ images[idx] = image
+ annotations[idx] = target
+
+ # transformations (resizing + normalization)
+ if self.do_resize and self.size is not None:
+ if annotations is not None:
+ for idx, (image, target) in enumerate(zip(images, annotations)):
+ image, target = self._resize(image=image, target=target, size=self.size, max_size=self.max_size)
+ images[idx] = image
+ annotations[idx] = target
+ else:
+ for idx, image in enumerate(images):
+ images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0]
+
+ if self.do_normalize:
+ if annotations is not None:
+ for idx, (image, target) in enumerate(zip(images, annotations)):
+ image, target = self._normalize(
+ image=image, mean=self.image_mean, std=self.image_std, target=target
+ )
+ images[idx] = image
+ annotations[idx] = target
+ else:
+ images = [
+ self._normalize(image=image, mean=self.image_mean, std=self.image_std)[0] for image in images
+ ]
+
+ if padding:
+ # pad images up to largest image in batch
+ max_size = self._max_by_axis([list(image.shape) for image in images])
+ c, h, w = max_size
+ padded_images = []
+ for image in images:
+ # create padded image
+ padded_image = np.zeros((c, h, w), dtype=np.float32)
+ padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
+ padded_images.append(padded_image)
+ images = padded_images
+
+ # return as BatchFeature
+ data = {}
+ data["pixel_values"] = images
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ if annotations is not None:
+ # Convert to TensorType
+ tensor_type = return_tensors
+ if not isinstance(tensor_type, TensorType):
+ tensor_type = TensorType(tensor_type)
+
+ if not tensor_type == TensorType.PYTORCH:
+ raise ValueError("Only PyTorch is supported for the moment.")
+ else:
+ if not is_torch_available():
+ raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
+
+ encoded_inputs["labels"] = [
+ {k: torch.from_numpy(v) for k, v in target.items()} for target in annotations
+ ]
+
+ return encoded_inputs
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor._max_by_axis
+ def _max_by_axis(self, the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+ def pad(self, pixel_values_list: List["torch.Tensor"], return_tensors: Optional[Union[str, TensorType]] = None):
+ """
+ Pad images up to the largest image in a batch.
+
+ Args:
+ pixel_values_list (`List[torch.Tensor]`):
+ List of images (pixel values) to be padded. Each image should be a tensor of shape (C, H, W).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
+ objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following field:
+
+ - **pixel_values** -- Pixel values to be fed to a model.
+
+ """
+
+ max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
+ c, h, w = max_size
+ padded_images = []
+ for image in pixel_values_list:
+ # create padded image
+ padded_image = np.zeros((c, h, w), dtype=np.float32)
+ padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
+ padded_images.append(padded_image)
+
+ # return as BatchFeature
+ data = {"pixel_values": padded_images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process
+ def post_process(self, outputs, target_sizes):
+ """
+ Converts the output of [`DetrForObjectDetection`] into the format expected by the COCO api. Only supports
+ PyTorch.
+
+ Args:
+ outputs ([`DetrObjectDetectionOutput`]):
+ Raw outputs of the model.
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+ Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
+ image size (before any data augmentation). For visualization, this should be the image size after data
+ augment, but before padding.
+
+ Returns:
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+ in the batch as predicted by the model.
+ """
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+ if len(out_logits) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+ if target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ prob = nn.functional.softmax(out_logits, -1)
+ scores, labels = prob[..., :-1].max(-1)
+
+ # convert to [x0, y0, x1, y1] format
+ boxes = center_to_corners_format(out_bbox)
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_segmentation
+ def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
+ """
+ Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.
+
+ Parameters:
+ outputs ([`DetrSegmentationOutput`]):
+ Raw outputs of the model.
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
+ Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
+ threshold (`float`, *optional*, defaults to 0.9):
+ Threshold to use to filter out queries.
+ mask_threshold (`float`, *optional*, defaults to 0.5):
+ Threshold to use when turning the predicted masks into binary values.
+
+ Returns:
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
+ in the batch as predicted by the model.
+ """
+ out_logits, raw_masks = outputs.logits, outputs.pred_masks
+ preds = []
+
+ def to_tuple(tup):
+ if isinstance(tup, tuple):
+ return tup
+ return tuple(tup.cpu().tolist())
+
+ for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
+ # we filter empty queries and detection below threshold
+ scores, labels = cur_logits.softmax(-1).max(-1)
+ keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
+ cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+ cur_scores = cur_scores[keep]
+ cur_classes = cur_classes[keep]
+ cur_masks = cur_masks[keep]
+ cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
+ cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
+
+ predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks}
+ preds.append(predictions)
+ return preds
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_instance
+ def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
+ """
+ Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
+ PyTorch.
+
+ Args:
+ results (`List[Dict]`):
+ Results list obtained by [`~DetrFeatureExtractor.post_process`], to which "masks" results will be
+ added.
+ outputs ([`DetrSegmentationOutput`]):
+ Raw outputs of the model.
+ orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+ Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
+ image size (before any data augmentation).
+ max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+ Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
+ original image size (before any data augmentation).
+ threshold (`float`, *optional*, defaults to 0.5):
+ Threshold to use when turning the predicted masks into binary values.
+
+ Returns:
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
+ image in the batch as predicted by the model.
+ """
+
+ if len(orig_target_sizes) != len(max_target_sizes):
+ raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
+ max_h, max_w = max_target_sizes.max(0)[0].tolist()
+ outputs_masks = outputs.pred_masks.squeeze(2)
+ outputs_masks = nn.functional.interpolate(
+ outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
+ )
+ outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
+
+ for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
+ img_h, img_w = t[0], t[1]
+ results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
+ results[i]["masks"] = nn.functional.interpolate(
+ results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
+ ).byte()
+
+ return results
+
+ # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_panoptic
+ def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
+ """
+ Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
+
+ Parameters:
+ outputs ([`DetrSegmentationOutput`]):
+ Raw outputs of the model.
+ processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
+ Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
+ augmentation but before batching.
+ target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
+ Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
+ None, it will default to the `processed_sizes`.
+ is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
+ Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
+ If not set, defaults to the `is_thing_map` of COCO panoptic.
+ threshold (`float`, *optional*, defaults to 0.85):
+ Threshold to use to filter out queries.
+
+ Returns:
+ `List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
+ an image in the batch as predicted by the model.
+ """
+ if target_sizes is None:
+ target_sizes = processed_sizes
+ if len(processed_sizes) != len(target_sizes):
+ raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
+
+ if is_thing_map is None:
+ # default to is_thing_map of COCO panoptic
+ is_thing_map = {i: i <= 90 for i in range(201)}
+
+ out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
+ if not len(out_logits) == len(raw_masks) == len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
+ )
+ preds = []
+
+ def to_tuple(tup):
+ if isinstance(tup, tuple):
+ return tup
+ return tuple(tup.cpu().tolist())
+
+ for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
+ out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
+ ):
+ # we filter empty queries and detection below threshold
+ scores, labels = cur_logits.softmax(-1).max(-1)
+ keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
+ cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+ cur_scores = cur_scores[keep]
+ cur_classes = cur_classes[keep]
+ cur_masks = cur_masks[keep]
+ cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
+ cur_boxes = center_to_corners_format(cur_boxes[keep])
+
+ h, w = cur_masks.shape[-2:]
+ if len(cur_boxes) != len(cur_classes):
+ raise ValueError("Not as many boxes as there are classes")
+
+ # It may be that we have several predicted masks for the same stuff class.
+ # In the following, we track the list of masks ids for each stuff class (they are merged later on)
+ cur_masks = cur_masks.flatten(1)
+ stuff_equiv_classes = defaultdict(lambda: [])
+ for k, label in enumerate(cur_classes):
+ if not is_thing_map[label.item()]:
+ stuff_equiv_classes[label.item()].append(k)
+
+ def get_ids_area(masks, scores, dedup=False):
+ # This helper function creates the final panoptic segmentation image
+ # It also returns the area of the masks that appears on the image
+
+ m_id = masks.transpose(0, 1).softmax(-1)
+
+ if m_id.shape[-1] == 0:
+ # We didn't detect any mask :(
+ m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
+ else:
+ m_id = m_id.argmax(-1).view(h, w)
+
+ if dedup:
+ # Merge the masks corresponding to the same stuff class
+ for equiv in stuff_equiv_classes.values():
+ if len(equiv) > 1:
+ for eq_id in equiv:
+ m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
+
+ final_h, final_w = to_tuple(target_size)
+
+ seg_img = Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
+ seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
+
+ np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
+ np_seg_img = np_seg_img.view(final_h, final_w, 3)
+ np_seg_img = np_seg_img.numpy()
+
+ m_id = torch.from_numpy(rgb_to_id(np_seg_img))
+
+ area = []
+ for i in range(len(scores)):
+ area.append(m_id.eq(i).sum().item())
+ return area, seg_img
+
+ area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
+ if cur_classes.numel() > 0:
+ # We know filter empty masks as long as we find some
+ while True:
+ filtered_small = torch.as_tensor(
+ [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
+ )
+ if filtered_small.any().item():
+ cur_scores = cur_scores[~filtered_small]
+ cur_classes = cur_classes[~filtered_small]
+ cur_masks = cur_masks[~filtered_small]
+ area, seg_img = get_ids_area(cur_masks, cur_scores)
+ else:
+ break
+
+ else:
+ cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
+
+ segments_info = []
+ for i, a in enumerate(area):
+ cat = cur_classes[i].item()
+ segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
+ del cur_classes
+
+ with io.BytesIO() as out:
+ seg_img.save(out, format="PNG")
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
+ preds.append(predictions)
+ return preds
diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py
new file mode 100755
index 00000000000000..578e8ca6092794
--- /dev/null
+++ b/src/transformers/models/yolos/modeling_yolos.py
@@ -0,0 +1,1324 @@
+# coding=utf-8
+# Copyright 2022 School of EIC, Huazhong University of Science & Technology and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch YOLOS model."""
+
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import Tensor, nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_scipy_available,
+ is_vision_available,
+ logging,
+ replace_return_docstrings,
+ requires_backends,
+)
+from .configuration_yolos import YolosConfig
+
+
+if is_scipy_available():
+ from scipy.optimize import linear_sum_assignment
+
+if is_vision_available():
+ from transformers.models.detr.feature_extraction_detr import center_to_corners_format
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "YolosConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "YolosFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "hustvl/yolos-small"
+_EXPECTED_OUTPUT_SHAPE = [1, 3401, 384]
+
+
+YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "hustvl/yolos-small",
+ # See all YOLOS models at https://huggingface.co/models?filter=yolos
+]
+
+
+@dataclass
+class YolosObjectDetectionOutput(ModelOutput):
+ """
+ Output type of [`YolosForObjectDetection`].
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
+ Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+ scale-invariant IoU loss.
+ loss_dict (`Dict`, *optional*):
+ A dictionary containing the individual losses. Useful for logging.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
+ Classification logits (including no-object) for all queries.
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+ possible padding). You can use [`~DetrFeatureExtractor.post_process`] to retrieve the unnormalized bounding
+ boxes.
+ auxiliary_outputs (`list[Dict]`, *optional*):
+ Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+ `pred_boxes`) for each decoder layer.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
+ the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ loss_dict: Optional[Dict] = None
+ logits: torch.FloatTensor = None
+ pred_boxes: torch.FloatTensor = None
+ auxiliary_outputs: Optional[List[Dict]] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+# Copied from transformers.models.vit.modeling_vit.to_2tuple
+def to_2tuple(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return (x, x)
+
+
+class YolosEmbeddings(nn.Module):
+ """
+ Construct the CLS token, detection tokens, position and patch embeddings.
+
+ """
+
+ def __init__(self, config: YolosConfig) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size))
+ self.patch_embeddings = PatchEmbeddings(
+ image_size=config.image_size,
+ patch_size=config.patch_size,
+ num_channels=config.num_channels,
+ embed_dim=config.hidden_size,
+ )
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(
+ torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size)
+ )
+
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.interpolation = InterpolateInitialPositionEmbeddings(config)
+ self.config = config
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values)
+
+ batch_size, seq_len, _ = embeddings.size()
+
+ # add the [CLS] and detection tokens to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ detection_tokens = self.detection_tokens.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings, detection_tokens), dim=1)
+
+ # add positional encoding to each token
+ # this might require interpolation of the existing position embeddings
+ position_embeddings = self.interpolation(self.position_embeddings, (height, width))
+
+ embeddings = embeddings + position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class InterpolateInitialPositionEmbeddings(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.config = config
+
+ def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
+ cls_pos_embed = pos_embed[:, 0, :]
+ cls_pos_embed = cls_pos_embed[:, None]
+ det_pos_embed = pos_embed[:, -self.config.num_detection_tokens :, :]
+ patch_pos_embed = pos_embed[:, 1 : -self.config.num_detection_tokens, :]
+ patch_pos_embed = patch_pos_embed.transpose(1, 2)
+ batch_size, hidden_size, seq_len = patch_pos_embed.shape
+
+ patch_height, patch_width = (
+ self.config.image_size[0] // self.config.patch_size,
+ self.config.image_size[1] // self.config.patch_size,
+ )
+ patch_pos_embed = patch_pos_embed.view(batch_size, hidden_size, patch_height, patch_width)
+
+ height, width = img_size
+ new_patch_heigth, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed, size=(new_patch_heigth, new_patch_width), mode="bicubic", align_corners=False
+ )
+ patch_pos_embed = patch_pos_embed.flatten(2).transpose(1, 2)
+ scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=1)
+ return scale_pos_embed
+
+
+class InterpolateMidPositionEmbeddings(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.config = config
+
+ def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
+ cls_pos_embed = pos_embed[:, :, 0, :]
+ cls_pos_embed = cls_pos_embed[:, None]
+ det_pos_embed = pos_embed[:, :, -self.config.num_detection_tokens :, :]
+ patch_pos_embed = pos_embed[:, :, 1 : -self.config.num_detection_tokens, :]
+ patch_pos_embed = patch_pos_embed.transpose(2, 3)
+ depth, batch_size, hidden_size, seq_len = patch_pos_embed.shape
+
+ patch_height, patch_width = (
+ self.config.image_size[0] // self.config.patch_size,
+ self.config.image_size[1] // self.config.patch_size,
+ )
+ patch_pos_embed = patch_pos_embed.view(depth * batch_size, hidden_size, patch_height, patch_width)
+ height, width = img_size
+ new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False
+ )
+ patch_pos_embed = (
+ patch_pos_embed.flatten(2)
+ .transpose(1, 2)
+ .contiguous()
+ .view(depth, batch_size, new_patch_height * new_patch_width, hidden_size)
+ )
+ scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=2)
+ return scale_pos_embed
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+class PatchEmbeddings(nn.Module):
+ """
+ Image to Patch Embedding.
+
+ """
+
+ def __init__(
+ self,
+ image_size: int = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ num_channels: int = 3,
+ embed_dim: int = 768,
+ ):
+ super().__init__()
+ image_size = to_2tuple(image_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Yolos
+class YolosSelfAttention(nn.Module):
+ def __init__(self, config: YolosConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
+class YolosSelfOutput(nn.Module):
+ """
+ The residual connection is defined in YolosLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: YolosConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Yolos
+class YolosAttention(nn.Module):
+ def __init__(self, config: YolosConfig) -> None:
+ super().__init__()
+ self.attention = YolosSelfAttention(config)
+ self.output = YolosSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
+class YolosIntermediate(nn.Module):
+ def __init__(self, config: YolosConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Yolos
+class YolosOutput(nn.Module):
+ def __init__(self, config: YolosConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos
+class YolosLayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: YolosConfig) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = YolosAttention(config)
+ self.intermediate = YolosIntermediate(config)
+ self.output = YolosOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in Yolos, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in Yolos, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+class YolosEncoder(nn.Module):
+ def __init__(self, config: YolosConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([YolosLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ seq_length = (
+ 1 + (config.image_size[0] * config.image_size[1] // config.patch_size**2) + config.num_detection_tokens
+ )
+ self.mid_position_embeddings = (
+ nn.Parameter(
+ torch.zeros(
+ config.num_hidden_layers - 1,
+ 1,
+ seq_length,
+ config.hidden_size,
+ )
+ )
+ if config.use_mid_position_embeddings
+ else None
+ )
+
+ self.interpolation = InterpolateMidPositionEmbeddings(config) if config.use_mid_position_embeddings else None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ height,
+ width,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if self.config.use_mid_position_embeddings:
+ interpolated_mid_position_embeddings = self.interpolation(self.mid_position_embeddings, (height, width))
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ layer_head_mask,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if self.config.use_mid_position_embeddings:
+ if i < (self.config.num_hidden_layers - 1):
+ hidden_states = hidden_states + interpolated_mid_position_embeddings[i]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class YolosPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = YolosConfig
+ base_model_prefix = "vit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module: YolosEncoder, value: bool = False) -> None:
+ if isinstance(module, YolosEncoder):
+ module.gradient_checkpointing = value
+
+
+YOLOS_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`YolosConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+YOLOS_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare YOLOS Model transformer outputting raw hidden-states without any specific head on top.",
+ YOLOS_START_DOCSTRING,
+)
+class YolosModel(YolosPreTrainedModel):
+ def __init__(self, config: YolosConfig, add_pooling_layer: bool = True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = YolosEmbeddings(config)
+ self.encoder = YolosEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = YolosPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> PatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model.
+
+ Args:
+ heads_to_prune (`dict` of {layer_num: list of heads to prune in this layer}):
+ See base class `PreTrainedModel`.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(YOLOS_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ height=pixel_values.shape[-2],
+ width=pixel_values.shape[-1],
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class YolosPooler(nn.Module):
+ def __init__(self, config: YolosConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@add_start_docstrings(
+ """
+ YOLOS Model (consisting of a ViT encoder) with object detection heads on top, for tasks such as COCO detection.
+ """,
+ YOLOS_START_DOCSTRING,
+)
+class YolosForObjectDetection(YolosPreTrainedModel):
+ def __init__(self, config: YolosConfig):
+ super().__init__(config)
+
+ # YOLOS (ViT) encoder model
+ self.vit = YolosModel(config, add_pooling_layer=False)
+
+ # Object detection heads
+ # We add one for the "no object" class
+ self.class_labels_classifier = YolosMLPPredictionHead(
+ input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=config.num_labels + 1, num_layers=3
+ )
+ self.bbox_predictor = YolosMLPPredictionHead(
+ input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=4, num_layers=3
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+ @add_start_docstrings_to_model_forward(YOLOS_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=YolosObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`List[Dict]` of len `(batch_size,)`, *optional*):
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+ following 2 keys: `'class_labels'` and `'boxes'` (the class labels and bounding boxes of an image in the
+ batch respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding
+ boxes in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image,
+ 4)`.
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import YolosFeatureExtractor, YolosForObjectDetection
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small")
+ >>> model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+
+ >>> # model predicts bounding boxes and corresponding COCO classes
+ >>> logits = outputs.logits
+ >>> bboxes = outputs.pred_boxes
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # First, sent images through YOLOS base model to obtain hidden states
+ outputs = self.vit(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ # Take the final hidden states of the detection tokens
+ sequence_output = sequence_output[:, -self.config.num_detection_tokens :, :]
+
+ # Class logits + predicted bounding boxes
+ logits = self.class_labels_classifier(sequence_output)
+ pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
+
+ loss, loss_dict, auxiliary_outputs = None, None, None
+ if labels is not None:
+ # First: create the matcher
+ matcher = YolosHungarianMatcher(
+ class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
+ )
+ # Second: create the criterion
+ losses = ["labels", "boxes", "cardinality"]
+ criterion = YolosLoss(
+ matcher=matcher,
+ num_classes=self.config.num_labels,
+ eos_coef=self.config.eos_coefficient,
+ losses=losses,
+ )
+ criterion.to(self.device)
+ # Third: compute the losses, based on outputs and labels
+ outputs_loss = {}
+ outputs_loss["logits"] = logits
+ outputs_loss["pred_boxes"] = pred_boxes
+ if self.config.auxiliary_loss:
+ intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
+ outputs_class = self.class_labels_classifier(intermediate)
+ outputs_coord = self.bbox_predictor(intermediate).sigmoid()
+ auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
+ outputs_loss["auxiliary_outputs"] = auxiliary_outputs
+
+ loss_dict = criterion(outputs_loss, labels)
+ # Fourth: compute total loss, as a weighted sum of the various losses
+ weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
+ weight_dict["loss_giou"] = self.config.giou_loss_coefficient
+ if self.config.auxiliary_loss:
+ aux_weight_dict = {}
+ for i in range(self.config.decoder_layers - 1):
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+ weight_dict.update(aux_weight_dict)
+ loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+
+ if not return_dict:
+ if auxiliary_outputs is not None:
+ output = (logits, pred_boxes) + auxiliary_outputs + outputs
+ else:
+ output = (logits, pred_boxes) + outputs
+ return ((loss, loss_dict) + output) if loss is not None else output
+
+ return YolosObjectDetectionOutput(
+ loss=loss,
+ loss_dict=loss_dict,
+ logits=logits,
+ pred_boxes=pred_boxes,
+ auxiliary_outputs=auxiliary_outputs,
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.detr.modeling_detr.dice_loss
+def dice_loss(inputs, targets, num_boxes):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs (0 for the negative class and 1 for the positive
+ class).
+ """
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ numerator = 2 * (inputs * targets).sum(1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss.sum() / num_boxes
+
+
+# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs (0 for the negative class and 1 for the positive
+ class).
+ alpha: (optional) Weighting factor in range (0,1) to balance
+ positive vs negative examples. Default = -1 (no weighting).
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
+
+ Returns:
+ Loss tensor
+ """
+ prob = inputs.sigmoid()
+ ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ return loss.mean(1).sum() / num_boxes
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrLoss with Detr->Yolos
+class YolosLoss(nn.Module):
+ """
+ This class computes the losses for YolosForObjectDetection/YolosForSegmentation. The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each
+ pair of matched ground-truth / prediction (supervise class and box).
+
+ A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
+ parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
+ the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
+ be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
+ (`max_obj_id` + 1). For more details on this, check the following discussion
+ https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
+
+
+ Args:
+ matcher (`YolosHungarianMatcher`):
+ Module able to compute a matching between targets and proposals.
+ num_classes (`int`):
+ Number of object categories, omitting the special no-object category.
+ eos_coef (`float`):
+ Relative classification weight applied to the no-object category.
+ losses (`List[str]`):
+ List of all the losses to be applied. See `get_loss` for a list of all available losses.
+ """
+
+ def __init__(self, matcher, num_classes, eos_coef, losses):
+ super().__init__()
+ self.matcher = matcher
+ self.num_classes = num_classes
+ self.eos_coef = eos_coef
+ self.losses = losses
+ empty_weight = torch.ones(self.num_classes + 1)
+ empty_weight[-1] = self.eos_coef
+ self.register_buffer("empty_weight", empty_weight)
+
+ # removed logging parameter, which was part of the original implementation
+ def loss_labels(self, outputs, targets, indices, num_boxes):
+ """
+ Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
+ [nb_target_boxes]
+ """
+ if "logits" not in outputs:
+ raise KeyError("No logits were found in the outputs")
+ src_logits = outputs["logits"]
+
+ idx = self._get_src_permutation_idx(indices)
+ target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
+ target_classes = torch.full(
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+ )
+ target_classes[idx] = target_classes_o
+
+ loss_ce = nn.functional.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+ losses = {"loss_ce": loss_ce}
+
+ return losses
+
+ @torch.no_grad()
+ def loss_cardinality(self, outputs, targets, indices, num_boxes):
+ """
+ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
+
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
+ """
+ logits = outputs["logits"]
+ device = logits.device
+ tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
+ card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
+ card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float())
+ losses = {"cardinality_error": card_err}
+ return losses
+
+ def loss_boxes(self, outputs, targets, indices, num_boxes):
+ """
+ Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
+
+ Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
+ are expected in format (center_x, center_y, w, h), normalized by the image size.
+ """
+ if "pred_boxes" not in outputs:
+ raise KeyError("No predicted boxes found in outputs")
+ idx = self._get_src_permutation_idx(indices)
+ src_boxes = outputs["pred_boxes"][idx]
+ target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+ loss_bbox = nn.functional.l1_loss(src_boxes, target_boxes, reduction="none")
+
+ losses = {}
+ losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+ loss_giou = 1 - torch.diag(
+ generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
+ )
+ losses["loss_giou"] = loss_giou.sum() / num_boxes
+ return losses
+
+ def loss_masks(self, outputs, targets, indices, num_boxes):
+ """
+ Compute the losses related to the masks: the focal loss and the dice loss.
+
+ Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
+ """
+ if "pred_masks" not in outputs:
+ raise KeyError("No predicted masks found in outputs")
+
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+ src_masks = outputs["pred_masks"]
+ src_masks = src_masks[src_idx]
+ masks = [t["masks"] for t in targets]
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+ target_masks = target_masks.to(src_masks)
+ target_masks = target_masks[tgt_idx]
+
+ # upsample predictions to the target size
+ src_masks = nn.functional.interpolate(
+ src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+ )
+ src_masks = src_masks[:, 0].flatten(1)
+
+ target_masks = target_masks.flatten(1)
+ target_masks = target_masks.view(src_masks.shape)
+ losses = {
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+ "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+ }
+ return losses
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+ def get_loss(self, loss, outputs, targets, indices, num_boxes):
+ loss_map = {
+ "labels": self.loss_labels,
+ "cardinality": self.loss_cardinality,
+ "boxes": self.loss_boxes,
+ "masks": self.loss_masks,
+ }
+ if loss not in loss_map:
+ raise ValueError(f"Loss {loss} not supported")
+ return loss_map[loss](outputs, targets, indices, num_boxes)
+
+ def forward(self, outputs, targets):
+ """
+ This performs the loss computation.
+
+ Args:
+ outputs (`dict`, *optional*):
+ Dictionary of tensors, see the output specification of the model for the format.
+ targets (`List[dict]`, *optional*):
+ List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
+ losses applied, see each loss' doc.
+ """
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
+
+ # Retrieve the matching between the outputs of the last layer and the targets
+ indices = self.matcher(outputs_without_aux, targets)
+
+ # Compute the average number of target boxes accross all nodes, for normalization purposes
+ num_boxes = sum(len(t["class_labels"]) for t in targets)
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+ # (Niels): comment out function below, distributed training to be added
+ # if is_dist_avail_and_initialized():
+ # torch.distributed.all_reduce(num_boxes)
+ # (Niels) in original implementation, num_boxes is divided by get_world_size()
+ num_boxes = torch.clamp(num_boxes, min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if "auxiliary_outputs" in outputs:
+ for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
+ indices = self.matcher(auxiliary_outputs, targets)
+ for loss in self.losses:
+ if loss == "masks":
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->Yolos
+class YolosMLPPredictionHead(nn.Module):
+ """
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
+ height and width of a bounding box w.r.t. an image.
+
+ Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+
+ """
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher with Detr->Yolos
+class YolosHungarianMatcher(nn.Module):
+ """
+ This class computes an assignment between the targets and the predictions of the network.
+
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
+ predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
+ un-matched (and thus treated as non-objects).
+
+ Args:
+ class_cost:
+ The relative weight of the classification error in the matching cost.
+ bbox_cost:
+ The relative weight of the L1 error of the bounding box coordinates in the matching cost.
+ giou_cost:
+ The relative weight of the giou loss of the bounding box in the matching cost.
+ """
+
+ def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
+ super().__init__()
+ requires_backends(self, ["scipy"])
+
+ self.class_cost = class_cost
+ self.bbox_cost = bbox_cost
+ self.giou_cost = giou_cost
+ if class_cost == 0 or bbox_cost == 0 or giou_cost == 0:
+ raise ValueError("All costs of the Matcher can't be 0")
+
+ @torch.no_grad()
+ def forward(self, outputs, targets):
+ """
+ Args:
+ outputs (`dict`):
+ A dictionary that contains at least these entries:
+ * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+ * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
+ targets (`List[dict]`):
+ A list of targets (len(targets) = batch_size), where each target is a dict containing:
+ * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
+ ground-truth
+ objects in the target) containing the class labels
+ * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
+
+ Returns:
+ `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ """
+ batch_size, num_queries = outputs["logits"].shape[:2]
+
+ # We flatten to compute the cost matrices in a batch
+ out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
+ out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
+
+ # Also concat the target labels and boxes
+ tgt_ids = torch.cat([v["class_labels"] for v in targets])
+ tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
+ class_cost = -out_prob[:, tgt_ids]
+
+ # Compute the L1 cost between boxes
+ bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+ # Compute the giou cost between boxes
+ giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(tgt_bbox))
+
+ # Final cost matrix
+ cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
+ cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
+
+ sizes = [len(v["boxes"]) for v in targets]
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+# Copied from transformers.models.detr.modeling_detr._upcast
+def _upcast(t: Tensor) -> Tensor:
+ # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
+ if t.is_floating_point():
+ return t if t.dtype in (torch.float32, torch.float64) else t.float()
+ else:
+ return t if t.dtype in (torch.int32, torch.int64) else t.int()
+
+
+# Copied from transformers.models.detr.modeling_detr.box_area
+def box_area(boxes: Tensor) -> Tensor:
+ """
+ Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
+
+ Args:
+ boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
+ Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
+ < x2` and `0 <= y1 < y2`.
+
+ Returns:
+ `torch.FloatTensor`: a tensor containing the area for each box.
+ """
+ boxes = _upcast(boxes)
+ return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+
+# Copied from transformers.models.detr.modeling_detr.box_iou
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
+ inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / union
+ return iou, union
+
+
+# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
+def generalized_box_iou(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
+
+ Returns:
+ `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+ iou, union = box_iou(boxes1, boxes2)
+
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ area = wh[:, :, 0] * wh[:, :, 1]
+
+ return iou - (area - union) / area
+
+
+# Copied from transformers.models.detr.modeling_detr._max_by_axis
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+# Copied from transformers.models.detr.modeling_detr.NestedTensor
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ if tensor_list[0].ndim == 3:
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ else:
+ raise ValueError("Only 3-dimensional tensors are supported")
+ return NestedTensor(tensor, mask)
diff --git a/src/transformers/models/yoso/__init__.py b/src/transformers/models/yoso/__init__.py
index 5dff89595ca15d..400a0303c0c711 100644
--- a/src/transformers/models/yoso/__init__.py
+++ b/src/transformers/models/yoso/__init__.py
@@ -18,14 +18,17 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
-_import_structure = {
- "configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
-}
+_import_structure = {"configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"]}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_yoso"] = [
"YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
"YosoForMaskedLM",
@@ -42,7 +45,12 @@
if TYPE_CHECKING:
from .configuration_yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_yoso import (
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
YosoForMaskedLM,
diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py
index bcd9c516cc8baa..50013ca03209e9 100644
--- a/src/transformers/models/yoso/modeling_yoso.py
+++ b/src/transformers/models/yoso/modeling_yoso.py
@@ -816,7 +816,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py
index 6e3b4404cd043c..6d665b35566f2c 100644
--- a/src/transformers/onnx/__main__.py
+++ b/src/transformers/onnx/__main__.py
@@ -15,9 +15,8 @@
from argparse import ArgumentParser
from pathlib import Path
-from ..models.auto import AutoConfig, AutoFeatureExtractor, AutoTokenizer
-from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
-from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
+from ..models.auto import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+from ..onnx.utils import get_preprocessor
from ..utils import logging
from .convert import export, validate_model_outputs
from .features import FeaturesManager
@@ -43,6 +42,13 @@ def main():
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
+ parser.add_argument(
+ "--preprocessor",
+ type=str,
+ choices=["auto", "tokenizer", "feature_extractor", "processor"],
+ default="auto",
+ help="Which type of preprocessor to use. 'auto' tries to automatically detect it.",
+ )
# Retrieve CLI arguments
args = parser.parse_args()
@@ -51,15 +57,17 @@ def main():
if not args.output.parent.exists():
args.output.parent.mkdir(parents=True)
- # Check the modality of the inputs and instantiate the appropriate preprocessor
- # TODO(lewtun): Refactor this as a function if we need to check modalities elsewhere as well
- config = AutoConfig.from_pretrained(args.model)
- if config.model_type in TOKENIZER_MAPPING_NAMES:
+ # Instantiate the appropriate preprocessor
+ if args.preprocessor == "auto":
+ preprocessor = get_preprocessor(args.model)
+ elif args.preprocessor == "tokenizer":
preprocessor = AutoTokenizer.from_pretrained(args.model)
- elif config.model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
+ elif args.preprocessor == "feature_extractor":
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
+ elif args.preprocessor == "processor":
+ preprocessor = AutoProcessor.from_pretrained(args.model)
else:
- raise ValueError(f"Unsupported model type: {config.model_type}")
+ raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
# Allocate the model
model = FeaturesManager.get_model_from_feature(
diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py
index 8f886a5d7a440d..f97d61ea401703 100644
--- a/src/transformers/onnx/config.py
+++ b/src/transformers/onnx/config.py
@@ -74,12 +74,11 @@ class OnnxConfig(ABC):
default_fixed_num_choices = 4
torch_onnx_minimum_version = version.parse("1.8")
_tasks_to_common_outputs = {
+ "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
+ "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+ "masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
- "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
- "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
- "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
- "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
"question-answering": OrderedDict(
{
@@ -87,7 +86,9 @@ class OnnxConfig(ABC):
"end_logits": {0: "batch", 1: "sequence"},
}
),
- "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+ "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
+ "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
+ "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
}
def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
@@ -292,7 +293,8 @@ def generate_dummy_inputs(
raise ValueError("You cannot provide both a tokenizer and a preprocessor to generate dummy inputs.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.warning("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -409,7 +411,8 @@ def num_layers(self) -> int:
"""
if not hasattr(self._config, "num_layers"):
raise AttributeError(
- "could not find the number of layers attribute in the model configuration, override the num_layers property of the model OnnxConfig to solve this"
+ "could not find the number of layers attribute in the model configuration, override the num_layers"
+ " property of the model OnnxConfig to solve this"
)
return self._config.num_layers
@@ -421,7 +424,8 @@ def num_attention_heads(self) -> int:
"""
if not hasattr(self._config, "num_attention_heads"):
raise AttributeError(
- "could not find the number of attention heads attribute in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this"
+ "could not find the number of attention heads attribute in the model configuration, override the"
+ " num_attention_heads property of the model OnnxConfig to solve this"
)
return self._config.num_attention_heads
@@ -456,8 +460,10 @@ def generate_dummy_inputs(
)
if "attention_mask" in common_inputs:
+ mask_dtype = common_inputs["attention_mask"].dtype
common_inputs["attention_mask"] = torch.cat(
- [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+ [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)],
+ dim=1,
)
common_inputs["past_key_values"] = []
@@ -527,7 +533,8 @@ def num_layers(self) -> Tuple[int]:
num_layers = (self._config.encoder_layers, self._config.decoder_layers)
else:
raise AttributeError(
- "could not find the number of encoder and decoder layers attributes in the model configuration, override the num_layers property of the model OnnxConfig to solve this"
+ "could not find the number of encoder and decoder layers attributes in the model configuration,"
+ " override the num_layers property of the model OnnxConfig to solve this"
)
return num_layers
@@ -542,7 +549,9 @@ def num_attention_heads(self) -> Tuple[int]:
num_attention_heads = (self._config.encoder_attention_heads, self._config.decoder_attention_heads)
else:
raise AttributeError(
- "could not find the number of attention heads for the encoder and the decoder attributes in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this"
+ "could not find the number of attention heads for the encoder and the decoder attributes in the"
+ " model configuration, override the num_attention_heads property of the model OnnxConfig to solve"
+ " this"
)
return num_attention_heads
diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py
index 69aca2a43accc9..43224532e6d23f 100644
--- a/src/transformers/onnx/convert.py
+++ b/src/transformers/onnx/convert.py
@@ -68,7 +68,7 @@ def check_onnxruntime_requirements(minimum_version: Version):
raise ImportError(
f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
- f"Please update onnxruntime by running `pip install --upgrade onnxruntime`"
+ "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
)
except ImportError:
@@ -86,6 +86,7 @@ def export_pytorch(
opset: int,
output: Path,
tokenizer: "PreTrainedTokenizer" = None,
+ device: str = "cpu",
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch model to an ONNX Intermediate Representation (IR)
@@ -101,6 +102,8 @@ def export_pytorch(
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
+ device (`str`, *optional*, defaults to `cpu`):
+ The device on which the ONNX model will be exported. Either `cpu` or `cuda`.
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -111,7 +114,8 @@ def export_pytorch(
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -136,6 +140,10 @@ def export_pytorch(
# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
+ device = torch.device(device)
+ if device.type == "cuda" and torch.cuda.is_available():
+ model.to(device)
+ model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
@@ -168,9 +176,13 @@ def export_pytorch(
message = str(err)
if (
message
- == "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without setting use_external_data_format parameter."
+ == "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without"
+ " setting use_external_data_format parameter."
):
- message = "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without setting use_external_data_format parameter or try with torch 1.10+."
+ message = (
+ "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export"
+ " without setting use_external_data_format parameter or try with torch 1.10+."
+ )
raise RuntimeError(message)
else:
raise err
@@ -227,7 +239,8 @@ def export_tensorflow(
raise ValueError("You cannot provide both a tokenizer and preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -262,6 +275,7 @@ def export(
opset: int,
output: Path,
tokenizer: "PreTrainedTokenizer" = None,
+ device: str = "cpu",
) -> Tuple[List[str], List[str]]:
"""
Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)
@@ -277,6 +291,9 @@ def export(
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
+ device (`str`, *optional*, defaults to `cpu`):
+ The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
+ export on CUDA devices.
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -288,11 +305,15 @@ def export(
"Please install torch or tensorflow first."
)
+ if is_tf_available() and isinstance(model, TFPreTrainedModel) and device == "cuda":
+ raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
+
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -306,11 +327,12 @@ def export(
if not config.is_torch_support_available:
logger.warning(
- f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version}, got: {torch_version}"
+ f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
+ f" got: {torch_version}"
)
if is_torch_available() and issubclass(type(model), PreTrainedModel):
- return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer)
+ return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer)
@@ -332,7 +354,8 @@ def validate_model_outputs(
raise ValueError("You cannot provide both a tokenizer and a preprocessor to validatethe model outputs.")
if tokenizer is not None:
warnings.warn(
- "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+ "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
+ " `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
@@ -350,6 +373,8 @@ def validate_model_outputs(
session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"])
# Compute outputs from the reference model
+ if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
+ reference_model.to("cpu")
ref_outputs = reference_model(**reference_model_inputs)
ref_outputs_dict = {}
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index a4d3a49388d601..66dc321d26d9fc 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -1,34 +1,9 @@
from functools import partial, reduce
from typing import Callable, Dict, Optional, Tuple, Type, Union
+import transformers
+
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
-from ..models.albert import AlbertOnnxConfig
-from ..models.bart import BartOnnxConfig
-from ..models.beit import BeitOnnxConfig
-from ..models.bert import BertOnnxConfig
-from ..models.big_bird import BigBirdOnnxConfig
-from ..models.blenderbot import BlenderbotOnnxConfig
-from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
-from ..models.camembert import CamembertOnnxConfig
-from ..models.convbert import ConvBertOnnxConfig
-from ..models.data2vec import Data2VecTextOnnxConfig
-from ..models.deit import DeiTOnnxConfig
-from ..models.distilbert import DistilBertOnnxConfig
-from ..models.electra import ElectraOnnxConfig
-from ..models.flaubert import FlaubertOnnxConfig
-from ..models.gpt2 import GPT2OnnxConfig
-from ..models.gpt_neo import GPTNeoOnnxConfig
-from ..models.gptj import GPTJOnnxConfig
-from ..models.ibert import IBertOnnxConfig
-from ..models.layoutlm import LayoutLMOnnxConfig
-from ..models.m2m_100 import M2M100OnnxConfig
-from ..models.marian import MarianOnnxConfig
-from ..models.mbart import MBartOnnxConfig
-from ..models.roberta import RobertaOnnxConfig
-from ..models.roformer import RoFormerOnnxConfig
-from ..models.t5 import T5OnnxConfig
-from ..models.vit import ViTOnnxConfig
-from ..models.xlm_roberta import XLMRobertaOnnxConfig
from ..utils import logging
from .config import OnnxConfig
@@ -61,19 +36,20 @@
)
if not is_torch_available() and not is_tf_available():
logger.warning(
- "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
+ "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models"
+ " without one of these libraries installed."
)
def supported_features_mapping(
- *supported_features: str, onnx_config_cls: Type[OnnxConfig] = None
+ *supported_features: str, onnx_config_cls: str = None
) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
"""
Generate the mapping between supported the features and their corresponding OnnxConfig for a given model.
Args:
*supported_features: The names of the supported features.
- onnx_config_cls: The OnnxConfig class corresponding to the model.
+ onnx_config_cls: The OnnxConfig full name corresponding to the model.
Returns:
The dictionary mapping a feature to an OnnxConfig constructor.
@@ -81,13 +57,16 @@ def supported_features_mapping(
if onnx_config_cls is None:
raise ValueError("A OnnxConfig class must be provided")
+ config_cls = transformers
+ for attr_name in onnx_config_cls.split("."):
+ config_cls = getattr(config_cls, attr_name)
mapping = {}
for feature in supported_features:
if "-with-past" in feature:
task = feature.replace("-with-past", "")
- mapping[feature] = partial(onnx_config_cls.with_past, task=task)
+ mapping[feature] = partial(config_cls.with_past, task=task)
else:
- mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
+ mapping[feature] = partial(config_cls.from_model_config, task=feature)
return mapping
@@ -129,7 +108,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=AlbertOnnxConfig,
+ onnx_config_cls="models.albert.AlbertOnnxConfig",
),
"bart": supported_features_mapping(
"default",
@@ -140,18 +119,11 @@ class FeaturesManager:
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
- onnx_config_cls=BartOnnxConfig,
+ onnx_config_cls="models.bart.BartOnnxConfig",
),
- "mbart": supported_features_mapping(
- "default",
- "default-with-past",
- "causal-lm",
- "causal-lm-with-past",
- "seq2seq-lm",
- "seq2seq-lm-with-past",
- "sequence-classification",
- "question-answering",
- onnx_config_cls=MBartOnnxConfig,
+ # BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
+ "beit": supported_features_mapping(
+ "default", "image-classification", onnx_config_cls="models.beit.BeitOnnxConfig"
),
"bert": supported_features_mapping(
"default",
@@ -161,7 +133,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=BertOnnxConfig,
+ onnx_config_cls="models.bert.BertOnnxConfig",
),
"big-bird": supported_features_mapping(
"default",
@@ -171,16 +143,36 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=BigBirdOnnxConfig,
+ onnx_config_cls="models.big_bird.BigBirdOnnxConfig",
),
- "ibert": supported_features_mapping(
+ "bigbird-pegasus": supported_features_mapping(
"default",
- "masked-lm",
+ "default-with-past",
+ "causal-lm",
+ "causal-lm-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
"sequence-classification",
- "multiple-choice",
- "token-classification",
"question-answering",
- onnx_config_cls=IBertOnnxConfig,
+ onnx_config_cls="models.bigbird_pegasus.BigBirdPegasusOnnxConfig",
+ ),
+ "blenderbot": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "causal-lm",
+ "causal-lm-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ onnx_config_cls="models.blenderbot.BlenderbotOnnxConfig",
+ ),
+ "blenderbot-small": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "causal-lm",
+ "causal-lm-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
),
"camembert": supported_features_mapping(
"default",
@@ -190,7 +182,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=CamembertOnnxConfig,
+ onnx_config_cls="models.camembert.CamembertOnnxConfig",
),
"convbert": supported_features_mapping(
"default",
@@ -199,40 +191,35 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=ConvBertOnnxConfig,
+ onnx_config_cls="models.convbert.ConvBertOnnxConfig",
),
- "distilbert": supported_features_mapping(
+ "convnext": supported_features_mapping(
+ "default",
+ "image-classification",
+ onnx_config_cls="models.convnext.ConvNextOnnxConfig",
+ ),
+ "data2vec-text": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=DistilBertOnnxConfig,
+ onnx_config_cls="models.data2vec.Data2VecTextOnnxConfig",
),
- "flaubert": supported_features_mapping(
+ "deit": supported_features_mapping(
+ "default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig"
+ ),
+ "distilbert": supported_features_mapping(
"default",
"masked-lm",
- "causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=FlaubertOnnxConfig,
+ onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
),
- "marian": supported_features_mapping(
- "default",
- "default-with-past",
- "seq2seq-lm",
- "seq2seq-lm-with-past",
- "causal-lm",
- "causal-lm-with-past",
- onnx_config_cls=MarianOnnxConfig,
- ),
- "m2m-100": supported_features_mapping(
- "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
- ),
- "roberta": supported_features_mapping(
+ "electra": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
@@ -240,12 +227,9 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=RobertaOnnxConfig,
+ onnx_config_cls="models.electra.ElectraOnnxConfig",
),
- "t5": supported_features_mapping(
- "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
- ),
- "xlm-roberta": supported_features_mapping(
+ "flaubert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
@@ -253,7 +237,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=XLMRobertaOnnxConfig,
+ onnx_config_cls="models.flaubert.FlaubertOnnxConfig",
),
"gpt2": supported_features_mapping(
"default",
@@ -262,7 +246,7 @@ class FeaturesManager:
"causal-lm-with-past",
"sequence-classification",
"token-classification",
- onnx_config_cls=GPT2OnnxConfig,
+ onnx_config_cls="models.gpt2.GPT2OnnxConfig",
),
"gptj": supported_features_mapping(
"default",
@@ -271,7 +255,7 @@ class FeaturesManager:
"causal-lm-with-past",
"question-answering",
"sequence-classification",
- onnx_config_cls=GPTJOnnxConfig,
+ onnx_config_cls="models.gptj.GPTJOnnxConfig",
),
"gpt-neo": supported_features_mapping(
"default",
@@ -279,60 +263,87 @@ class FeaturesManager:
"causal-lm",
"causal-lm-with-past",
"sequence-classification",
- onnx_config_cls=GPTNeoOnnxConfig,
+ onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
),
- "layoutlm": supported_features_mapping(
+ "ibert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
+ "multiple-choice",
"token-classification",
- onnx_config_cls=LayoutLMOnnxConfig,
+ "question-answering",
+ onnx_config_cls="models.ibert.IBertOnnxConfig",
),
- "electra": supported_features_mapping(
+ "layoutlm": supported_features_mapping(
"default",
"masked-lm",
- "causal-lm",
"sequence-classification",
- "multiple-choice",
"token-classification",
- "question-answering",
- onnx_config_cls=ElectraOnnxConfig,
+ onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
),
- "vit": supported_features_mapping(
- "default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
- ),
- "beit": supported_features_mapping(
- "default", "image-classification", "masked-im", onnx_config_cls=BeitOnnxConfig
+ "longt5": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ onnx_config_cls="models.longt5.LongT5OnnxConfig",
),
- "deit": supported_features_mapping(
- "default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
+ "marian": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ "causal-lm",
+ "causal-lm-with-past",
+ onnx_config_cls="models.marian.MarianOnnxConfig",
),
- "blenderbot": supported_features_mapping(
+ "mbart": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
- onnx_config_cls=BlenderbotOnnxConfig,
+ "sequence-classification",
+ "question-answering",
+ onnx_config_cls="models.mbart.MBartOnnxConfig",
),
- "blenderbot-small": supported_features_mapping(
+ "mobilebert": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "sequence-classification",
+ "multiple-choice",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
+ ),
+ "m2m-100": supported_features_mapping(
"default",
"default-with-past",
- "causal-lm",
- "causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
- onnx_config_cls=BlenderbotSmallOnnxConfig,
+ onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
),
- "data2vec-text": supported_features_mapping(
+ "perceiver": supported_features_mapping(
+ "image-classification",
+ "masked-lm",
+ "sequence-classification",
+ onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
+ ),
+ "resnet": supported_features_mapping(
+ "default",
+ "image-classification",
+ onnx_config_cls="models.resnet.ResNetOnnxConfig",
+ ),
+ "roberta": supported_features_mapping(
"default",
"masked-lm",
+ "causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
- onnx_config_cls=Data2VecTextOnnxConfig,
+ onnx_config_cls="models.roberta.RobertaOnnxConfig",
),
"roformer": supported_features_mapping(
"default",
@@ -343,7 +354,46 @@ class FeaturesManager:
"multiple-choice",
"question-answering",
"token-classification",
- onnx_config_cls=RoFormerOnnxConfig,
+ onnx_config_cls="models.roformer.RoFormerOnnxConfig",
+ ),
+ "squeezebert": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "sequence-classification",
+ "multiple-choice",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
+ ),
+ "t5": supported_features_mapping(
+ "default",
+ "default-with-past",
+ "seq2seq-lm",
+ "seq2seq-lm-with-past",
+ onnx_config_cls="models.t5.T5OnnxConfig",
+ ),
+ "vit": supported_features_mapping(
+ "default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
+ ),
+ "xlm": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "causal-lm",
+ "sequence-classification",
+ "multiple-choice",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.xlm.XLMOnnxConfig",
+ ),
+ "xlm-roberta": supported_features_mapping(
+ "default",
+ "masked-lm",
+ "causal-lm",
+ "sequence-classification",
+ "multiple-choice",
+ "token-classification",
+ "question-answering",
+ onnx_config_cls="models.xlm_roberta.XLMRobertaOnnxConfig",
),
}
@@ -416,8 +466,7 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS
if task not in task_to_automodel:
raise KeyError(
- f"Unknown task: {feature}. "
- f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
+ f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
)
return task_to_automodel[task]
@@ -470,8 +519,22 @@ def check_supported_model_or_raise(
model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
if feature not in model_features:
raise ValueError(
- f"{model.config.model_type} doesn't support feature {feature}. "
- f"Supported values are: {model_features}"
+ f"{model.config.model_type} doesn't support feature {feature}. Supported values are: {model_features}"
)
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
+
+ def get_config(model_type: str, feature: str) -> OnnxConfig:
+ """
+ Gets the OnnxConfig for a model_type and feature combination.
+
+ Args:
+ model_type (`str`):
+ The model type to retrieve the config for.
+ feature (`str`):
+ The feature to retrieve the config for.
+
+ Returns:
+ `OnnxConfig`: config for the combination
+ """
+ return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
diff --git a/src/transformers/onnx/utils.py b/src/transformers/onnx/utils.py
index def160e6c7bb39..9672b0a96af88f 100644
--- a/src/transformers/onnx/utils.py
+++ b/src/transformers/onnx/utils.py
@@ -14,6 +14,11 @@
from ctypes import c_float, sizeof
from enum import Enum
+from typing import TYPE_CHECKING, Optional, Union
+
+
+if TYPE_CHECKING:
+ from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
class ParameterFormat(Enum):
@@ -61,3 +66,44 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
Size (in byte) taken to save all the parameters
"""
return num_parameters * dtype.size
+
+
+def get_preprocessor(model_name: str) -> Optional[Union["AutoTokenizer", "AutoFeatureExtractor", "AutoProcessor"]]:
+ """
+ Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`.
+
+ Args:
+ model_name (`str`): Name of the model for which a preprocessor are loaded.
+
+ Returns:
+ `Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]`:
+ If a processor is found, it is returned. Otherwise, if a tokenizer or a feature extractor exists, it is
+ returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns
+ `None` if no preprocessor is found.
+ """
+ # Avoid circular imports by only importing this here.
+ from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
+
+ try:
+ return AutoProcessor.from_pretrained(model_name)
+ except (ValueError, OSError, KeyError):
+ tokenizer, feature_extractor = None, None
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ except (OSError, KeyError):
+ pass
+ try:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
+ except (OSError, KeyError):
+ pass
+
+ if tokenizer is not None and feature_extractor is not None:
+ raise ValueError(
+ f"Couldn't auto-detect preprocessor for {model_name}. Found both a tokenizer and a feature extractor."
+ )
+ elif tokenizer is None and feature_extractor is None:
+ return None
+ elif tokenizer is not None:
+ return tokenizer
+ else:
+ return feature_extractor
diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index 60b9dca7831b76..b957acb6de9395 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -304,8 +304,9 @@ def __init__(
):
if not no_deprecation_warning:
warnings.warn(
- "This implementation of AdamW is deprecated and will be removed in a future version. Use the"
- " PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning",
+ "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
+ " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
+ " warning",
FutureWarning,
)
require_version("torch>=1.5.0") # add_ with alpha
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index 1350669e45167a..d5c59f30782e4e 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -29,6 +29,7 @@
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer
+from ..tokenization_utils_fast import PreTrainedTokenizerFast
from ..utils import http_get, is_tf_available, is_torch_available, logging
from .audio_classification import AudioClassificationPipeline
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
@@ -60,6 +61,7 @@
TokenClassificationArgumentHandler,
TokenClassificationPipeline,
)
+from .visual_question_answering import VisualQuestionAnsweringPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
@@ -93,6 +95,7 @@
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
@@ -108,6 +111,7 @@
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
+ AutoModelForVisualQuestionAnswering,
)
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
@@ -120,6 +124,7 @@
TASK_ALIASES = {
"sentiment-analysis": "text-classification",
"ner": "token-classification",
+ "vqa": "visual-question-answering",
}
SUPPORTED_TASKS = {
"audio-classification": {
@@ -189,6 +194,19 @@
},
"type": "text",
},
+ "visual-question-answering": {
+ "impl": VisualQuestionAnsweringPipeline,
+ "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
+ "tf": (),
+ "default": {
+ "model": {
+ "pt": "dandelin/vilt-b32-finetuned-vqa",
+ "tokenizer": "dandelin/vilt-b32-finetuned-vqa",
+ "feature_extractor": "dandelin/vilt-b32-finetuned-vqa",
+ },
+ },
+ "type": "multimodal",
+ },
"fill-mask": {
"impl": FillMaskPipeline,
"tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
@@ -373,7 +391,7 @@ def pipeline(
task: str = None,
model: Optional = None,
config: Optional[Union[str, PretrainedConfig]] = None,
- tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
+ tokenizer: Optional[Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
framework: Optional[str] = None,
revision: Optional[str] = None,
@@ -500,15 +518,15 @@ def pipeline(
if model is None and tokenizer is not None:
raise RuntimeError(
- "Impossible to instantiate a pipeline with tokenizer specified but not the model "
- "as the provided tokenizer may not be compatible with the default model. "
- "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing tokenizer."
+ "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer"
+ " may not be compatible with the default model. Please provide a PreTrainedModel class or a"
+ " path/identifier to a pretrained model when providing tokenizer."
)
if model is None and feature_extractor is not None:
raise RuntimeError(
- "Impossible to instantiate a pipeline with feature_extractor specified but not the model "
- "as the provided feature_extractor may not be compatible with the default model. "
- "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor."
+ "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided"
+ " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
+ " or a path/identifier to a pretrained model when providing feature_extractor."
)
if task is None and model is not None:
@@ -642,7 +660,9 @@ def pipeline(
kwargs["decoder"] = decoder
except ImportError as e:
logger.warning(
- f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
+ f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install"
+ " `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install"
+ f" https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
)
if task == "translation" and model.config.task_specific_params:
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index d54a17df1e9dd3..1565463e0e753b 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -75,14 +75,19 @@ def _pad(items, key, padding_value, padding_side):
# Others include `attention_mask` etc...
shape = items[0][key].shape
dim = len(shape)
- if dim == 4:
+ if key == "pixel_values":
# This is probable image so padding shouldn't be necessary
# B, C, H, W
return torch.cat([item[key] for item in items], dim=0)
max_length = max(item[key].shape[1] for item in items)
+ min_length = min(item[key].shape[1] for item in items)
dtype = items[0][key].dtype
if dim == 2:
+ if max_length == min_length:
+ # Bypass for `ImageGPT` which doesn't provide a padding value, yet
+ # we can consistently pad since the size should be matching
+ return torch.cat([item[key] for item in items], dim=0)
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
elif dim == 3:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
@@ -139,13 +144,18 @@ def inner(items):
for item in items:
if set(item.keys()) != keys:
raise ValueError(
- f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} != {keys})"
+ f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} !="
+ f" {keys})"
)
# input_values, input_pixels, input_ids, ...
padded = {}
for key in keys:
if key in {"input_ids"}:
- _padding_value = t_padding_value
+ # ImageGPT uses a feature extractor
+ if feature_extractor is not None:
+ _padding_value = f_padding_value
+ else:
+ _padding_value = t_padding_value
elif key in {"input_values", "pixel_values", "input_features"}:
_padding_value = f_padding_value
elif key in {"p_mask", "special_tokens_mask"}:
@@ -692,7 +702,7 @@ def predict(self, X):
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
- the associated CUDA device id.
+ the associated CUDA device id. You can pass native `torch.device` too.
binary_output (`bool`, *optional*, defaults to `False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
@@ -749,7 +759,10 @@ def __init__(
self.feature_extractor = feature_extractor
self.modelcard = modelcard
self.framework = framework
- self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
+ if is_torch_available() and isinstance(device, torch.device):
+ self.device = device
+ else:
+ self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
self.binary_output = binary_output
# Special handling
@@ -856,6 +869,8 @@ def _ensure_tensor_on_device(self, inputs, device):
elif isinstance(inputs, tuple):
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
elif isinstance(inputs, torch.Tensor):
+ if device == torch.device("cpu") and inputs.dtype in {torch.float16, torch.bfloat16}:
+ inputs = inputs.float()
return inputs.to(device)
else:
return inputs
@@ -879,7 +894,8 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models:
logger.error(
- f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}."
+ f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are"
+ f" {supported_models}."
)
@abstractmethod
@@ -994,7 +1010,8 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs):
self.call_count += 1
if self.call_count > 10 and self.framework == "pt" and self.device.type == "cuda":
warnings.warn(
- "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
+ "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
+ " dataset",
UserWarning,
)
@@ -1058,7 +1075,8 @@ def get_iterator(
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if num_workers > 1:
logger.warning(
- "For ChunkPipeline using num_workers>0 is likely to result in errors since everything is iterable, setting `num_workers=1` to guarantee correctness."
+ "For ChunkPipeline using num_workers>0 is likely to result in errors since everything is iterable,"
+ " setting `num_workers=1` to guarantee correctness."
)
num_workers = 1
dataset = PipelineChunkIterator(inputs, self.preprocess, preprocess_params)
diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py
index 517b457a654b46..f461f6faa2af65 100644
--- a/src/transformers/pipelines/fill_mask.py
+++ b/src/transformers/pipelines/fill_mask.py
@@ -167,7 +167,7 @@ def get_target_ids(self, targets, top_k=None):
if len(input_ids) == 0:
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
- f"We cannot replace it with anything meaningful, ignoring it"
+ "We cannot replace it with anything meaningful, ignoring it"
)
continue
id_ = input_ids[0]
diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py
index c629f703a030f0..0f5fbf0370e708 100644
--- a/src/transformers/pipelines/question_answering.py
+++ b/src/transformers/pipelines/question_answering.py
@@ -228,8 +228,8 @@ def __call__(self, *args, **kwargs):
max_answer_len (`int`, *optional*, defaults to 15):
The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
max_seq_len (`int`, *optional*, defaults to 384):
- The maximum length of the total sentence (context + question) after tokenization. The context will be
- split in several chunks (using `doc_stride`) if needed.
+ The maximum length of the total sentence (context + question) in tokens of each chunk passed to the
+ model. The context will be split in several chunks (using `doc_stride` as overlap) if needed.
max_question_len (`int`, *optional*, defaults to 64):
The maximum length of the question after tokenization. It will be truncated if needed.
handle_impossible_answer (`bool`, *optional*, defaults to `False`):
@@ -279,7 +279,6 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio
truncation="only_second" if question_first else "only_first",
max_length=max_seq_len,
stride=doc_stride,
- return_tensors="np",
return_token_type_ids=True,
return_overflowing_tokens=True,
return_offsets_mapping=True,
@@ -294,12 +293,10 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
- p_mask = np.asarray(
- [
- [tok != 1 if question_first else 0 for tok in encoded_inputs.sequence_ids(span_id)]
- for span_id in range(num_spans)
- ]
- )
+ p_mask = [
+ [tok != 1 if question_first else 0 for tok in encoded_inputs.sequence_ids(span_id)]
+ for span_id in range(num_spans)
+ ]
features = []
for span_idx in range(num_spans):
@@ -316,8 +313,6 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio
for cls_index in cls_indices:
p_mask[span_idx][cls_index] = 0
submask = p_mask[span_idx]
- if isinstance(submask, np.ndarray):
- submask = submask.tolist()
features.append(
SquadFeatures(
input_ids=input_ids_span_idx,
@@ -344,7 +339,7 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_questio
for i, feature in enumerate(features):
fw_args = {}
others = {}
- model_input_names = self.tokenizer.model_input_names + ["p_mask"]
+ model_input_names = self.tokenizer.model_input_names + ["p_mask", "token_type_ids"]
for k, v in feature.__dict__.items():
if k in model_input_names:
@@ -398,8 +393,11 @@ def postprocess(
end_ = np.where(undesired_tokens_mask, -10000.0, end_)
# Normalize logits and spans to retrieve the answer
- start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
- end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
+ start_ = np.exp(start_ - start_.max(axis=-1, keepdims=True))
+ start_ = start_ / start_.sum()
+
+ end_ = np.exp(end_ - end_.max(axis=-1, keepdims=True))
+ end_ = end_ / end_.sum()
if handle_impossible_answer:
min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item())
diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py
index d94bb6d061ff6d..25dcd320cf4f6a 100644
--- a/src/transformers/pipelines/table_question_answering.py
+++ b/src/transformers/pipelines/table_question_answering.py
@@ -56,14 +56,14 @@ def __call__(self, table=None, query=None, **kwargs):
tqa_pipeline_inputs = table
else:
raise ValueError(
- f"If keyword argument `table` is a list of dictionaries, each dictionary should have a `table` "
- f"and `query` key, but only dictionary has keys {table[0].keys()} `table` and `query` keys."
+ "If keyword argument `table` is a list of dictionaries, each dictionary should have a `table`"
+ f" and `query` key, but only dictionary has keys {table[0].keys()} `table` and `query` keys."
)
elif Dataset is not None and isinstance(table, Dataset) or isinstance(table, types.GeneratorType):
return table
else:
raise ValueError(
- f"Invalid input. Keyword argument `table` should be either of type `dict` or `list`, but "
+ "Invalid input. Keyword argument `table` should be either of type `dict` or `list`, but "
f"is {type(table)})"
)
else:
diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py
index 3d3f4e533d45ad..590c87c02201c7 100644
--- a/src/transformers/pipelines/text_classification.py
+++ b/src/transformers/pipelines/text_classification.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Dict
import numpy as np
@@ -72,15 +73,26 @@ def __init__(self, **kwargs):
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
)
- def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, **tokenizer_kwargs):
+ def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs):
+ # Using "" as default argument because we're going to use `top_k=None` in user code to declare
+ # "No top_k"
preprocess_params = tokenizer_kwargs
postprocess_params = {}
if hasattr(self.model.config, "return_all_scores") and return_all_scores is None:
return_all_scores = self.model.config.return_all_scores
- if return_all_scores is not None:
- postprocess_params["return_all_scores"] = return_all_scores
+ if isinstance(top_k, int) or top_k is None:
+ postprocess_params["top_k"] = top_k
+ postprocess_params["_legacy"] = False
+ elif return_all_scores is not None:
+ warnings.warn(
+ "`return_all_scores` is now deprecated, use `top_k=1` if you want similar functionnality", UserWarning
+ )
+ if return_all_scores:
+ postprocess_params["top_k"] = None
+ else:
+ postprocess_params["top_k"] = 1
if isinstance(function_to_apply, str):
function_to_apply = ClassificationFunction[function_to_apply.upper()]
@@ -94,10 +106,11 @@ def __call__(self, *args, **kwargs):
Classify the text(s) given as inputs.
Args:
- args (`str` or `List[str]`):
- One or several texts (or one list of prompts) to classify.
- return_all_scores (`bool`, *optional*, defaults to `False`):
- Whether to return scores for all labels.
+ args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
+ One or several texts to classify. In order to use text pairs for your classification, you can send a
+ dictionnary containing `{"text", "text_pair"}` keys, or a list of those.
+ top_k (`int`, *optional*, defaults to `1`):
+ How many results to return.
function_to_apply (`str`, *optional*, defaults to `"default"`):
The function to apply to the model outputs in order to retrieve the scores. Accepts four different
values:
@@ -120,10 +133,10 @@ def __call__(self, *args, **kwargs):
- **label** (`str`) -- The label predicted.
- **score** (`float`) -- The corresponding probability.
- If `self.return_all_scores=True`, one such dictionary is returned per label.
+ If `top_k` is used, one such dictionary is returned per label.
"""
result = super().__call__(*args, **kwargs)
- if isinstance(args[0], str):
+ if isinstance(args[0], str) and isinstance(result, dict):
# This pipeline is odd, and return a list when single item is run
return [result]
else:
@@ -131,12 +144,28 @@ def __call__(self, *args, **kwargs):
def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, GenericTensor]:
return_tensors = self.framework
+ if isinstance(inputs, dict):
+ return self.tokenizer(**inputs, return_tensors=return_tensors, **tokenizer_kwargs)
+ elif isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], list) and len(inputs[0]) == 2:
+ # It used to be valid to use a list of list of list for text pairs, keeping this path for BC
+ return self.tokenizer(
+ text=inputs[0][0], text_pair=inputs[0][1], return_tensors=return_tensors, **tokenizer_kwargs
+ )
+ elif isinstance(inputs, list):
+ # This is likely an invalid usage of the pipeline attempting to pass text pairs.
+ raise ValueError(
+ "The pipeline received invalid inputs, if you are trying to send text pairs, you can try to send a"
+ ' dictionnary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.'
+ )
return self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
def _forward(self, model_inputs):
return self.model(**model_inputs)
- def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=False):
+ def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):
+ # `_legacy` is used to determine if we're running the naked pipeline and in backward
+ # compatibility mode, or if running the pipeline with `pipeline(..., top_k=1)` we're running
+ # the more natural result containing the list.
# Default value before `set_parameters`
if function_to_apply is None:
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
@@ -160,7 +189,14 @@ def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=F
else:
raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
- if return_all_scores:
- return [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)]
- else:
+ if top_k == 1 and _legacy:
return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()}
+
+ dict_scores = [
+ {"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)
+ ]
+ if not _legacy:
+ dict_scores.sort(key=lambda x: x["score"], reverse=True)
+ if top_k is not None:
+ dict_scores = dict_scores[:top_k]
+ return dict_scores
diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index dbaa0a9df75a1f..4f210871a2441b 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -103,7 +103,8 @@ def _sanitize_parameters(
if handle_long_generation is not None:
if handle_long_generation not in {"hole"}:
raise ValueError(
- f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected [None, 'hole']"
+ f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected"
+ " [None, 'hole']"
)
preprocess_params["handle_long_generation"] = handle_long_generation
@@ -192,7 +193,8 @@ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **gene
keep_length = self.tokenizer.model_max_length - new_tokens
if keep_length <= 0:
raise ValueError(
- "We cannot use `hole` to handle this generation the number of desired tokens exceeds the models max length"
+ "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
+ " models max length"
)
inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py
index 4ea8d114150def..72f0c5c9c73823 100644
--- a/src/transformers/pipelines/token_classification.py
+++ b/src/transformers/pipelines/token_classification.py
@@ -133,11 +133,13 @@ def _sanitize_parameters(
if grouped_entities is not None:
warnings.warn(
- f'`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
+ "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to"
+ f' `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if ignore_subwords is not None:
warnings.warn(
- f'`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
+ "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to"
+ f' `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if aggregation_strategy is not None:
diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py
new file mode 100644
index 00000000000000..34a7a3b10d40fe
--- /dev/null
+++ b/src/transformers/pipelines/visual_question_answering.py
@@ -0,0 +1,115 @@
+from typing import Union
+
+from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging
+from .base import PIPELINE_INIT_ARGS, Pipeline
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from ..image_utils import load_image
+
+if is_torch_available():
+ from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
+
+logger = logging.get_logger(__name__)
+
+
+@add_end_docstrings(PIPELINE_INIT_ARGS)
+class VisualQuestionAnsweringPipeline(Pipeline):
+ """
+ Visual Question Answering pipeline using a `AutoModelForVisualQuestionAnswering`. This pipeline is currently only
+ available in PyTorch.
+
+ This visual question answering pipeline can currently be loaded from [`pipeline`] using the following task
+ identifiers: `"visual-question-answering", "vqa"`.
+
+ The models that this pipeline can use are models that have been fine-tuned on a visual question answering task. See
+ the up-to-date list of available models on
+ [huggingface.co/models](https://huggingface.co/models?filter=visual-question-answering).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING)
+
+ def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs):
+ preprocess_params, postprocess_params = {}, {}
+ if padding is not None:
+ preprocess_params["padding"] = padding
+ if truncation is not None:
+ preprocess_params["truncation"] = truncation
+ if top_k is not None:
+ postprocess_params["top_k"] = top_k
+ return preprocess_params, {}, postprocess_params
+
+ def __call__(self, image: Union["Image.Image", str], question: str = None, **kwargs):
+ r"""
+ Answers open-ended questions about images. The pipeline accepts several types of inputs which are detailed
+ below:
+
+ - `pipeline(image=image, question=question)`
+ - `pipeline({"image": image, "question": question})`
+ - `pipeline([{"image": image, "question": question}])`
+ - `pipeline([{"image": image, "question": question}, {"image": image, "question": question}])`
+
+ Args:
+ image (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+ The pipeline handles three types of images:
+
+ - A string containing a http link pointing to an image
+ - A string containing a local path to an image
+ - An image loaded in PIL directly
+
+ The pipeline accepts either a single image or a batch of images. If given a single image, it can be
+ broadcasted to multiple questions.
+ question (`str`, `List[str]`):
+ The question(s) asked. If given a single question, it can be broadcasted to multiple images.
+ top_k (`int`, *optional*, defaults to 5):
+ The number of top labels that will be returned by the pipeline. If the provided number is higher than
+ the number of labels available in the model configuration, it will default to the number of labels.
+ Return:
+ A dictionary or a list of dictionaries containing the result. The dictionaries contain the following keys:
+
+ - **label** (`str`) -- The label identified by the model.
+ - **score** (`int`) -- The score attributed by the model for that label.
+ """
+ if isinstance(image, (Image.Image, str)) and isinstance(question, str):
+ inputs = {"image": image, "question": question}
+ else:
+ """
+ Supports the following format
+ - {"image": image, "question": question}
+ - [{"image": image, "question": question}]
+ - Generator and datasets
+ """
+ inputs = image
+ results = super().__call__(inputs, **kwargs)
+ return results
+
+ def preprocess(self, inputs, padding=False, truncation=False):
+ image = load_image(inputs["image"])
+ model_inputs = self.tokenizer(
+ inputs["question"], return_tensors=self.framework, padding=padding, truncation=truncation
+ )
+ image_features = self.feature_extractor(images=image, return_tensors=self.framework)
+ model_inputs.update(image_features)
+ return model_inputs
+
+ def _forward(self, model_inputs):
+ model_outputs = self.model(**model_inputs)
+ return model_outputs
+
+ def postprocess(self, model_outputs, top_k=5):
+ if top_k > self.model.config.num_labels:
+ top_k = self.model.config.num_labels
+
+ if self.framework == "pt":
+ probs = model_outputs.logits.sigmoid()[0]
+ scores, ids = probs.topk(top_k)
+ else:
+ raise ValueError(f"Unsupported framework: {self.framework}")
+
+ scores = scores.tolist()
+ ids = ids.tolist()
+ return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py
index 9d5d5bd61b781e..f98c87166ca037 100644
--- a/src/transformers/pipelines/zero_shot_classification.py
+++ b/src/transformers/pipelines/zero_shot_classification.py
@@ -86,7 +86,8 @@ def _parse_and_tokenize(
if self.tokenizer.pad_token is None:
# Override for tokenizers not supporting padding
logger.error(
- "Tokenizer was not supporting padding necessary for zero-shot, attempting to use `pad_token=eos_token`"
+ "Tokenizer was not supporting padding necessary for zero-shot, attempting to use "
+ " `pad_token=eos_token`"
)
self.tokenizer.pad_token = self.tokenizer.eos_token
try:
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 36f56d2eeb29c6..1aebe8f4e2debf 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -22,6 +22,7 @@
import sys
import tempfile
import unittest
+from collections.abc import Mapping
from distutils.util import strtobool
from io import StringIO
from pathlib import Path
@@ -39,12 +40,14 @@
is_wandb_available,
)
from .utils import (
+ is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_detectron2_available,
is_faiss_available,
is_flax_available,
is_ftfy_available,
+ is_ipex_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
@@ -68,6 +71,7 @@
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
+ is_torchdynamo_available,
is_vision_available,
)
@@ -203,10 +207,7 @@ def slow(test_case):
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
- if not _run_slow_tests:
- return unittest.skip("test is slow")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def tooslow(test_case):
@@ -227,10 +228,7 @@ def custom_tokenizers(test_case):
Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
environment variable to a truthy value to run them.
"""
- if not _run_custom_tokenizers:
- return unittest.skip("test of custom tokenizers")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
def require_git_lfs(test_case):
@@ -240,34 +238,29 @@ def require_git_lfs(test_case):
git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
variable to a truthy value to run them.
"""
- if not _run_git_lfs_tests:
- return unittest.skip("test of git lfs workflow")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)
+
+
+def require_accelerate(test_case):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
def require_rjieba(test_case):
"""
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
"""
- if not is_rjieba_available():
- return unittest.skip("test requires rjieba")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
def require_tf2onnx(test_case):
- if not is_tf2onnx_available():
- return unittest.skip("test requires tf2onnx")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
def require_onnx(test_case):
- if not is_onnx_available():
- return unittest.skip("test requires ONNX")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
def require_timm(test_case):
@@ -277,10 +270,7 @@ def require_timm(test_case):
These tests are skipped when Timm isn't installed.
"""
- if not is_timm_available():
- return unittest.skip("test requires Timm")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
def require_torch(test_case):
@@ -290,10 +280,17 @@ def require_torch(test_case):
These tests are skipped when PyTorch isn't installed.
"""
- if not is_torch_available():
- return unittest.skip("test requires PyTorch")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+
+
+def require_intel_extension_for_pytorch(test_case):
+ """
+ Decorator marking a test that requires Intel Extension for PyTorch.
+
+ These tests are skipped when Intel Extension for PyTorch isn't installed.
+
+ """
+ return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case)
def require_torch_scatter(test_case):
@@ -303,10 +300,7 @@ def require_torch_scatter(test_case):
These tests are skipped when PyTorch scatter isn't installed.
"""
- if not is_scatter_available():
- return unittest.skip("test requires PyTorch scatter")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_scatter_available(), "test requires PyTorch scatter")(test_case)
def require_tensorflow_probability(test_case):
@@ -316,89 +310,65 @@ def require_tensorflow_probability(test_case):
These tests are skipped when TensorFlow probability isn't installed.
"""
- if not is_tensorflow_probability_available():
- return unittest.skip("test requires TensorFlow probability")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
+ test_case
+ )
def require_torchaudio(test_case):
"""
Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
"""
- if not is_torchaudio_available():
- return unittest.skip("test requires torchaudio")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
def require_tf(test_case):
"""
Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
"""
- if not is_tf_available():
- return unittest.skip("test requires TensorFlow")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)
def require_flax(test_case):
"""
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
"""
- if not is_flax_available():
- test_case = unittest.skip("test requires JAX & Flax")(test_case)
- return test_case
+ return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
def require_sentencepiece(test_case):
"""
Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
"""
- if not is_sentencepiece_available():
- return unittest.skip("test requires SentencePiece")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
def require_scipy(test_case):
"""
Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed.
"""
- if not is_scipy_available():
- return unittest.skip("test requires Scipy")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
def require_tokenizers(test_case):
"""
Decorator marking a test that requires š¤ Tokenizers. These tests are skipped when š¤ Tokenizers isn't installed.
"""
- if not is_tokenizers_available():
- return unittest.skip("test requires tokenizers")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
def require_pandas(test_case):
"""
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
"""
- if not is_pandas_available():
- return unittest.skip("test requires pandas")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
def require_pytesseract(test_case):
"""
Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
"""
- if not is_pytesseract_available():
- return unittest.skip("test requires PyTesseract")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)
def require_scatter(test_case):
@@ -406,10 +376,7 @@ def require_scatter(test_case):
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
installed.
"""
- if not is_scatter_available():
- return unittest.skip("test requires PyTorch Scatter")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_scatter_available(), "test requires PyTorch Scatter")(test_case)
def require_pytorch_quantization(test_case):
@@ -417,10 +384,9 @@ def require_pytorch_quantization(test_case):
Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
Quantization Toolkit isn't installed.
"""
- if not is_pytorch_quantization_available():
- return unittest.skip("test requires PyTorch Quantization Toolkit")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
+ test_case
+ )
def require_vision(test_case):
@@ -428,30 +394,21 @@ def require_vision(test_case):
Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
installed.
"""
- if not is_vision_available():
- return unittest.skip("test requires vision")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)
def require_ftfy(test_case):
"""
Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
"""
- if not is_ftfy_available():
- return unittest.skip("test requires ftfy")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)
def require_spacy(test_case):
"""
Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
"""
- if not is_spacy_available():
- return unittest.skip("test requires spacy")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
def require_torch_multi_gpu(test_case):
@@ -466,10 +423,7 @@ def require_torch_multi_gpu(test_case):
import torch
- if torch.cuda.device_count() < 2:
- return unittest.skip("test requires multiple GPUs")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
def require_torch_non_multi_gpu(test_case):
@@ -481,10 +435,7 @@ def require_torch_non_multi_gpu(test_case):
import torch
- if torch.cuda.device_count() > 1:
- return unittest.skip("test requires 0 or 1 GPU")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
def require_torch_up_to_2_gpus(test_case):
@@ -496,20 +447,14 @@ def require_torch_up_to_2_gpus(test_case):
import torch
- if torch.cuda.device_count() > 2:
- return unittest.skip("test requires 0 or 1 or 2 GPUs")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
- if not is_torch_tpu_available():
- return unittest.skip("test requires PyTorch TPU")
- else:
- return test_case
+ return unittest.skipUnless(is_torch_tpu_available(), "test requires PyTorch TPU")(test_case)
if is_torch_available():
@@ -531,44 +476,39 @@ def require_torch_tpu(test_case):
jax_device = None
+def require_torchdynamo(test_case):
+ """Decorator marking a test that requires TorchDynamo"""
+ return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
+
+
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
- if torch_device != "cuda":
- return unittest.skip("test requires CUDA")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
def require_torch_bf16(test_case):
- """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
- if not is_torch_bf16_available():
- return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
- else:
- return test_case
+ """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU."""
+ return unittest.skipUnless(
+ is_torch_bf16_available(),
+ "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU",
+ )(test_case)
def require_torch_tf32(test_case):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
- if not is_torch_tf32_available():
- return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(
+ is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
+ )(test_case)
def require_detectron2(test_case):
"""Decorator marking a test that requires detectron2."""
- if not is_detectron2_available():
- return unittest.skip("test requires `detectron2`")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)
def require_faiss(test_case):
"""Decorator marking a test that requires faiss."""
- if not is_faiss_available():
- return unittest.skip("test requires `faiss`")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
def require_optuna(test_case):
@@ -578,10 +518,7 @@ def require_optuna(test_case):
These tests are skipped when optuna isn't installed.
"""
- if not is_optuna_available():
- return unittest.skip("test requires optuna")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)
def require_ray(test_case):
@@ -591,10 +528,7 @@ def require_ray(test_case):
These tests are skipped when Ray/tune isn't installed.
"""
- if not is_ray_available():
- return unittest.skip("test requires Ray/tune")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)
def require_sigopt(test_case):
@@ -604,10 +538,7 @@ def require_sigopt(test_case):
These tests are skipped when SigOpt isn't installed.
"""
- if not is_sigopt_available():
- return unittest.skip("test requires SigOpt")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
def require_wandb(test_case):
@@ -617,10 +548,7 @@ def require_wandb(test_case):
These tests are skipped when wandb isn't installed.
"""
- if not is_wandb_available():
- return unittest.skip("test requires wandb")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
def require_soundfile(test_case):
@@ -630,80 +558,56 @@ def require_soundfile(test_case):
These tests are skipped when soundfile isn't installed.
"""
- if not is_soundfile_availble():
- return unittest.skip("test requires soundfile")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case)
def require_deepspeed(test_case):
"""
Decorator marking a test that requires deepspeed
"""
- if not is_deepspeed_available():
- return unittest.skip("test requires deepspeed")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
- if not is_fairscale_available():
- return unittest.skip("test requires fairscale")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)
def require_apex(test_case):
"""
Decorator marking a test that requires apex
"""
- if not is_apex_available():
- return unittest.skip("test requires apex")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
def require_bitsandbytes(test_case):
"""
Decorator for bits and bytes (bnb) dependency
"""
- if not is_bitsandbytes_available():
- return unittest.skip("test requires bnb")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case)
def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer
"""
- if not is_phonemizer_available():
- return unittest.skip("test requires phonemizer")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)
def require_pyctcdecode(test_case):
"""
Decorator marking a test that requires pyctcdecode
"""
- if not is_pyctcdecode_available():
- return unittest.skip("test requires pyctcdecode")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)
def require_librosa(test_case):
"""
Decorator marking a test that requires librosa
"""
- if not is_librosa_available():
- return unittest.skip("test requires librosa")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
def cmd_exists(cmd):
@@ -714,10 +618,7 @@ def require_usr_bin_time(test_case):
"""
Decorator marking a test that requires `/usr/bin/time`
"""
- if not cmd_exists("/usr/bin/time"):
- return unittest.skip("test requires /usr/bin/time")(test_case)
- else:
- return test_case
+ return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)
def get_gpu_count():
@@ -1585,13 +1486,11 @@ def nested_simplify(obj, decimals=3):
"""
import numpy as np
- from transformers.tokenization_utils import BatchEncoding
-
if isinstance(obj, list):
return [nested_simplify(item, decimals) for item in obj]
elif isinstance(obj, np.ndarray):
return nested_simplify(obj.tolist())
- elif isinstance(obj, (dict, BatchEncoding)):
+ elif isinstance(obj, Mapping):
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
elif isinstance(obj, (str, int, np.int64)):
return obj
@@ -1607,3 +1506,20 @@ def nested_simplify(obj, decimals=3):
return nested_simplify(obj.item(), decimals)
else:
raise Exception(f"Not supported: {type(obj)}")
+
+
+def check_json_file_has_correct_format(file_path):
+ with open(file_path, "r") as f:
+ lines = f.readlines()
+ if len(lines) == 1:
+ # length can only be 1 if dict is empty
+ assert lines[0] == "{}"
+ else:
+ # otherwise make sure json has correct format (at least 3 lines)
+ assert len(lines) >= 3
+ # each key one line, ident should be 2, min length is 3
+ assert lines[0].strip() == "{"
+ for line in lines[1:-1]:
+ left_indent = len(lines[1]) - len(lines[1].lstrip())
+ assert left_indent == 2
+ assert lines[-1].strip() == "}"
diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index 694b55cedd3e7f..6d33266c03f4e7 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -250,7 +250,8 @@ def cut_text(self, text, offsets):
for end in offsets:
if start > end:
logger.error(
- "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
+ "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
+ " anyway."
)
continue
elif start == end:
@@ -627,11 +628,13 @@ def get_input_ids(text):
else:
if is_split_into_words:
raise ValueError(
- f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_split_into_words=True`."
+ f"Input {text} is not valid. Should be a string or a list/tuple of strings when"
+ " `is_split_into_words=True`."
)
else:
raise ValueError(
- f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
+ f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of"
+ " integers."
)
if return_offsets_mapping:
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index d75b05c057866a..619e138ca126fb 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -24,6 +24,7 @@
import re
import warnings
from collections import OrderedDict, UserDict
+from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
@@ -1501,12 +1502,12 @@ def max_len_single_sentence(self, value) -> int:
if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:
if not self.deprecation_warnings.get("max_len_single_sentence", False):
logger.warning(
- "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
+ "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up."
)
self.deprecation_warnings["max_len_single_sentence"] = True
else:
raise ValueError(
- "Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
+ "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up."
)
@max_len_sentences_pair.setter
@@ -1515,13 +1516,11 @@ def max_len_sentences_pair(self, value) -> int:
if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:
if not self.deprecation_warnings.get("max_len_sentences_pair", False):
logger.warning(
- "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
+ "Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up."
)
self.deprecation_warnings["max_len_sentences_pair"] = True
else:
- raise ValueError(
- "Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
- )
+ raise ValueError("Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up.")
def _set_processor_class(self, processor_class: str):
"""Sets processor class as an attribute."""
@@ -1529,9 +1528,10 @@ def _set_processor_class(self, processor_class: str):
def __repr__(self) -> str:
return (
- f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}', "
- f"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast}, "
- f"padding_side='{self.padding_side}', truncation_side='{self.truncation_side}', special_tokens={self.special_tokens_map_extended})"
+ f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}',"
+ f" vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast},"
+ f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}',"
+ f" special_tokens={self.special_tokens_map_extended})"
)
def get_vocab(self) -> Dict[str, int]:
@@ -1872,10 +1872,10 @@ def _from_pretrained(
if config_tokenizer_class is not None:
if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
logger.warning(
- "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
- "It may result in unexpected tokenization. \n"
- f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
- f"The class this function is called from is '{cls.__name__}'."
+ "The tokenizer class you load from this checkpoint is not the same type as the class this"
+ " function is called from. It may result in unexpected tokenization. \nThe tokenizer class you"
+ f" load from this checkpoint is '{config_tokenizer_class}'. \nThe class this function is called"
+ f" from is '{cls.__name__}'."
)
# Update with newly provided kwargs
@@ -1898,9 +1898,19 @@ def convert_added_tokens(obj: Union[AddedToken, Any]):
if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer
# wont index sequences longer than the number of positional embeddings
+
model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path]
if model_max_length is not None and isinstance(model_max_length, (int, float)):
- init_kwargs["model_max_length"] = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length)
+
+ model_max_length = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length)
+ # TODO(PVP) - uncomment following line in Transformers v5
+ # init_kwargs["model_max_length"] = model_max_length
+ # TODO(PVP) - remove in Transformers v5
+ # ---
+ init_kwargs["model_max_length"] = cls._eventually_correct_t5_max_length(
+ pretrained_model_name_or_path, model_max_length, init_kwargs.get("model_max_length")
+ )
+ # ---
# Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
@@ -1954,34 +1964,57 @@ def convert_added_tokens(obj: Union[AddedToken, Any]):
# Sort added tokens by index
added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
+ # Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
+ # individual tokens would repeatedly rebuild a trie, which can be slow.
+ is_last_special = None
+ tokens = []
+
for token, index in added_tok_encoder_sorted:
- if has_tokenizer_file and index != len(tokenizer) and tokenizer.convert_tokens_to_ids(token) != index:
+ current_index = len(tokenizer) + len(tokens)
+ if has_tokenizer_file and index != current_index and tokenizer.convert_tokens_to_ids(token) != index:
# Tokenizer fast: added token needs to either be in the vocabulary with the proper index or the
# index is the current length of the tokenizer (not in vocabulary)
raise ValueError(
f"Wrong index found for {token}: should be {tokenizer.convert_tokens_to_ids(token)} but found "
f"{index}."
)
- elif not has_tokenizer_file and index != len(tokenizer):
+ elif not has_tokenizer_file and index != current_index:
# Tokenizer slow: added token cannot already be in the vocabulary so its index needs to be the
# current length of the tokenizer.
raise ValueError(
f"Non-consecutive added token '{token}' found. "
- f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
+ f"Should have index {current_index} but has index {index} in saved vocabulary."
)
- # Safe to call on a tokenizer fast even if token already there.
- tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
+ is_special = bool(token in special_tokens)
+ if is_last_special is None or is_last_special == is_special:
+ tokens.append(token)
+ else:
+ tokenizer.add_tokens(tokens, special_tokens=is_last_special)
+ tokens = [token]
+ is_last_special = is_special
+
+ if tokens:
+ tokenizer.add_tokens(tokens, special_tokens=is_last_special)
# Check all our special tokens are registered as "no split" token (we don't cut them) and are in the vocab
added_tokens = tokenizer.sanitize_special_tokens()
if added_tokens:
logger.warning_advice(
- "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are"
+ " fine-tuned or trained."
)
return tokenizer
+ @staticmethod
+ def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+ # This method should be deleted in Transformers v5
+ # Its only purpose is to potentially throw a warning
+ # that incorrectly defined max lengths of T5's tokenizer are used
+ # which we will correct in Transformers v5.
+ return max_model_length
+
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -2085,13 +2118,15 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
custom_object_save(self, save_directory, config=tokenizer_config)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(tokenizer_config, ensure_ascii=False))
+ out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
# Sanitize AddedTokens in special_tokens_map
write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(write_dict, ensure_ascii=False))
+ out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
logger.info(f"Special tokens file saved in {special_tokens_map_file}")
file_names = (tokenizer_config_file, special_tokens_map_file)
@@ -2135,7 +2170,7 @@ def _save_pretrained(
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
- out_str = json.dumps(added_vocab, ensure_ascii=False)
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str)
logger.info(f"added tokens file saved in {added_tokens_file}")
@@ -2251,11 +2286,11 @@ def _get_padding_truncation_strategies(
if verbose:
if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
logger.warning(
- "Truncation was not explicitly activated but `max_length` is provided a specific value, "
- "please use `truncation=True` to explicitly truncate examples to max length. "
- "Defaulting to 'longest_first' truncation strategy. "
- "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
- "more precisely by providing a specific strategy to `truncation`."
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, please"
+ " use `truncation=True` to explicitly truncate examples to max length. Defaulting to"
+ " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the"
+ " tokenizer you can select this strategy more precisely by providing a specific strategy to"
+ " `truncation`."
)
self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
truncation = "longest_first"
@@ -2297,14 +2332,14 @@ def _get_padding_truncation_strategies(
if truncation is False and old_truncation_strategy != "do_not_truncate":
if verbose:
warnings.warn(
- "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
- "use `truncation=True` to truncate examples to a max length. You can give a specific "
- "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
- "maximal input size of the model (e.g. 512 for Bert). "
- " If you have pairs of inputs, you can give a specific truncation strategy selected among "
- "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
- "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
- "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
+ "The `truncation_strategy` argument is deprecated and will be removed in a future version, use"
+ " `truncation=True` to truncate examples to a max length. You can give a specific length with"
+ " `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the maximal input"
+ " size of the model (e.g. 512 for Bert). If you have pairs of inputs, you can give a specific"
+ " truncation strategy selected among `truncation='only_first'` (will only truncate the first"
+ " sentence in the pairs) `truncation='only_second'` (will only truncate the second sentence in the"
+ " pairs) or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence"
+ " in the pairs).",
FutureWarning,
)
truncation_strategy = TruncationStrategy(old_truncation_strategy)
@@ -2327,8 +2362,8 @@ def _get_padding_truncation_strategies(
if verbose:
if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
logger.warning(
- "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
- "Default to no padding."
+ "Asking to pad to max_length but no maximum length is provided and the model has no"
+ " predefined maximum length. Default to no padding."
)
self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
padding_strategy = PaddingStrategy.DO_NOT_PAD
@@ -2340,8 +2375,8 @@ def _get_padding_truncation_strategies(
if verbose:
if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
logger.warning(
- "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
- "Default to no truncation."
+ "Asking to truncate to max_length but no maximum length is provided and the model has"
+ " no predefined maximum length. Default to no truncation."
)
self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
@@ -2365,7 +2400,7 @@ def _get_padding_truncation_strategies(
and (max_length % pad_to_multiple_of != 0)
):
raise ValueError(
- f"Truncation and padding are both activated but "
+ "Truncation and padding are both activated but "
f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
)
@@ -2448,11 +2483,13 @@ def _is_valid_text_input(t):
if is_batched:
if isinstance(text_pair, str):
raise TypeError(
- "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`."
+ "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as"
+ " `text`."
)
if text_pair is not None and len(text) != len(text_pair):
raise ValueError(
- f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+ f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+ f" {len(text_pair)}."
)
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
return self.batch_encode_plus(
@@ -2768,7 +2805,7 @@ def pad(
"""
# If we have a list of dicts, let's convert it in a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
- if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
# The model's main input name, usually `input_ids`, has be passed for padding
@@ -2807,7 +2844,7 @@ def pad(
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
- f"Should be one of a python, numpy, pytorch or tensorflow object."
+ "Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in encoded_inputs.items():
@@ -3104,16 +3141,17 @@ def truncate_sequences(
)
if truncation_strategy == TruncationStrategy.ONLY_FIRST:
error_msg = (
- error_msg + "Please select another truncation strategy than "
+ error_msg
+ + "Please select another truncation strategy than "
f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
)
logger.error(error_msg)
elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
logger.warning(
- f"Be aware, overflowing tokens are not returned for the setting you have chosen,"
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
- f"truncation strategy. So the returned list will always be empty even if some "
- f"tokens have been removed."
+ "truncation strategy. So the returned list will always be empty even if some "
+ "tokens have been removed."
)
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
@@ -3146,7 +3184,7 @@ def truncate_sequences(
f"We need to remove {num_tokens_to_remove} to truncate the input "
f"but the second sequence has a length {len(pair_ids)}. "
f"Please select another truncation strategy than {truncation_strategy}, "
- f"for instance 'longest_first' or 'only_first'."
+ "for instance 'longest_first' or 'only_first'."
)
return (ids, pair_ids, overflowing_tokens)
diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py
index 4f85a842dd3d2d..cdb606e7c60d94 100644
--- a/src/transformers/tokenization_utils_fast.py
+++ b/src/transformers/tokenization_utils_fast.py
@@ -21,6 +21,7 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
+import tokenizers.pre_tokenizers as pre_tokenizers_fast
from tokenizers import Encoding as EncodingFast
from tokenizers import Tokenizer as TokenizerFast
from tokenizers.decoders import Decoder as DecoderFast
@@ -567,8 +568,8 @@ def _save_pretrained(
if self.slow_tokenizer_class is None and legacy_format is True:
raise ValueError(
- "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You "
- "might consider leaving the legacy_format at `None` or setting it to `False`."
+ "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You"
+ " might consider leaving the legacy_format at `None` or setting it to `False`."
)
save_slow = (
@@ -585,7 +586,7 @@ def _save_pretrained(
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
- out_str = json.dumps(added_vocab, ensure_ascii=False)
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str)
vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
@@ -699,6 +700,8 @@ def train_new_from_iterator(
kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
kwargs["unk_token"] = unk_token
+ if tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel":
+ kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()
trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index eed68631384009..663500716f1d85 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -17,6 +17,8 @@
"""
import contextlib
+import functools
+import glob
import inspect
import math
import os
@@ -63,10 +65,10 @@
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, unwrap_model
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .optimization import Adafactor, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
@@ -103,40 +105,50 @@
BestRun,
EvalLoopOutput,
EvalPrediction,
+ FSDPOption,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
PredictionOutput,
+ RemoveColumnsCollator,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
default_hp_space,
denumpify_detensorize,
+ enable_full_determinism,
+ find_executable_batch_size,
get_last_checkpoint,
has_length,
number_of_arguments,
+ seed_worker,
set_seed,
speed_metrics,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
CONFIG_NAME,
+ WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
find_labels,
get_full_repo_name,
is_apex_available,
is_datasets_available,
is_in_notebook,
+ is_ipex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
+ is_torchdynamo_available,
logging,
)
+from .utils.generic import ContextManagers
_is_torch_generator_available = False
-_is_native_amp_available = False
+_is_native_cuda_amp_available = False
+_is_native_cpu_amp_available = False
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@@ -151,8 +163,10 @@
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_torch_generator_available = True
- _is_native_amp_available = True
- from torch.cuda.amp import autocast
+ _is_native_cuda_amp_available = True
+
+if version.parse(torch.__version__) >= version.parse("1.10"):
+ _is_native_cpu_amp_available = True
if is_datasets_available():
import datasets
@@ -296,7 +310,7 @@ def __init__(
args = TrainingArguments(output_dir=output_dir)
self.args = args
# Seed must be set before instantiating the model when using model
- set_seed(self.args.seed)
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
self.hp_name = None
self.deepspeed = None
self.is_in_train = False
@@ -321,8 +335,9 @@ def __init__(
else:
if model_init is not None:
warnings.warn(
- "`Trainer` requires either a `model` or `model_init` argument, but not both. "
- "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
+ "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
+ " overwrite your model when calling the `train` method. This will become a fatal error in the next"
+ " release.",
FutureWarning,
)
self.model_init = model_init
@@ -339,6 +354,10 @@ def __init__(
raise ValueError(
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
+ if len(args.fsdp) > 0:
+ raise ValueError(
+ "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
+ )
if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
@@ -356,6 +375,30 @@ def __init__(
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
+ self.fsdp = None
+ if len(args.fsdp) > 0:
+ if args.deepspeed:
+ raise ValueError(
+ "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
+ )
+ if args.local_rank == -1:
+ raise ValueError("Using fsdp only works in distributed training.")
+
+ # dep_version_check("torch>=1.12.0.dev20220418+cu113")
+ # Would have to update setup.py with torch>=1.12.0.dev20220418+cu113
+ # which isn't ideally given that it's a dev version
+ # and it will force people not using FSDP to also use torch>=1.12.0.dev20220418+cu113
+ # below is the current alternative.
+ if version.parse(torch.__version__) < version.parse("1.12.0.dev20220418+cu113"):
+ raise ValueError("FSDP requires PyTorch >= 1.12.0.dev20220418+cu113")
+
+ from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
+
+ if FSDPOption.FULL_SHARD in args.fsdp:
+ self.fsdp = ShardingStrategy.FULL_SHARD
+ elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
+ self.fsdp = ShardingStrategy.SHARD_GRAD_OP
+
# one place to sort out whether to place the model on device or not
# postpone switching model to cuda when:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
@@ -363,12 +406,14 @@ def __init__(
# and we only use deepspeed for training at the moment
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
# 4. Sharded DDP - same as MP
+ # 5. FSDP - same as MP
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or args.deepspeed
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
+ or (self.fsdp is not None)
):
self.place_model_on_device = False
@@ -397,11 +442,11 @@ def __init__(
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
- if (self.sharded_ddp is not None or args.deepspeed) and (
+ if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
self.optimizer is not None or self.lr_scheduler is not None
):
raise RuntimeError(
- "Passing `optimizers` is not allowed if Fairscale or Deepspeed is enabled."
+ "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
@@ -446,28 +491,55 @@ def __init__(
# Mixed precision setup
self.use_apex = False
- self.use_amp = False
+ self.use_cuda_amp = False
+ self.use_cpu_amp = False
+
+ # Mixed precision setup for SageMaker Model Parallel
+ if is_sagemaker_mp_enabled():
+ # BF16 + model parallelism in SageMaker: currently not supported, raise an error
+ if args.bf16:
+ raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
+ # When there's mismatch between SMP config and trainer argument, use SMP config as truth
+ if args.fp16 != smp.state.cfg.fp16:
+ logger.warning(
+ f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
+ f"but FP16 provided in trainer argument is {args.fp16},"
+ f"setting to {smp.state.cfg.fp16}"
+ )
+ args.fp16 = smp.state.cfg.fp16
if args.fp16 or args.bf16:
+ if self.fsdp is not None:
+ raise ValueError(
+ "Mixed precision is currently not supported for FSDP."
+ "Please do not set arguments related to `mixed_precision`"
+ )
if args.half_precision_backend == "auto":
- if _is_native_amp_available:
- args.half_precision_backend = "amp"
+ if args.device == torch.device("cpu"):
+ if args.fp16:
+ raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+ elif _is_native_cpu_amp_available:
+ args.half_precision_backend = "cpu_amp"
+ else:
+ raise ValueError("Tried to use cpu amp but native cpu amp is not available")
else:
- if args.bf16:
+ if _is_native_cuda_amp_available:
+ args.half_precision_backend = "cuda_amp"
+ elif args.bf16:
raise ValueError("Tried to use `bf16` but native amp is not available")
else:
args.half_precision_backend = "apex"
+
logger.info(f"Using {args.half_precision_backend} half precision backend")
self.do_grad_scaling = False
- if (args.fp16 or args.bf16) and not args.deepspeed: # deepspeed manages its own half precision
- if args.half_precision_backend == "amp":
- self.use_amp = True
+ if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+ # deepspeed and SageMaker Model Parallel manage their own half precision
+ if args.half_precision_backend == "cuda_amp":
+ self.use_cuda_amp = True
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
self.do_grad_scaling = True
- if is_sagemaker_mp_enabled():
- self.scaler = smp.amp.GradScaler()
- elif self.sharded_ddp is not None:
+ if self.sharded_ddp is not None:
self.scaler = ShardedGradScaler()
elif is_torch_tpu_available():
from torch_xla.amp import GradScaler
@@ -475,15 +547,24 @@ def __init__(
self.scaler = GradScaler()
else:
self.scaler = torch.cuda.amp.GradScaler()
+ elif args.half_precision_backend == "cpu_amp":
+ self.use_cpu_amp = True
+ self.amp_dtype = torch.bfloat16
else:
if not is_apex_available():
raise ImportError(
- "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
+ "Using FP16 with APEX but APEX is not installed, please refer to"
+ " https://www.github.com/nvidia/apex."
)
self.use_apex = True
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
- if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0:
+ if (
+ is_sagemaker_mp_enabled()
+ and self.use_cuda_amp
+ and args.max_grad_norm is not None
+ and args.max_grad_norm > 0
+ ):
raise ValueError(
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
"along 'max_grad_norm': 0 in your hyperparameters."
@@ -510,6 +591,9 @@ def __init__(
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+ # Internal variables to keep track of the original batch size
+ self._train_batch_size = args.train_batch_size
+
# very last
self._memory_tracker.stop_and_update_metrics()
@@ -557,27 +641,31 @@ def _move_model_to_device(self, model, device):
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
model.tie_weights()
- def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
- if not self.args.remove_unused_columns:
- return dataset
+ def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
- self._signature_columns += ["label", "label_ids"]
+ self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+ if not self.args.remove_unused_columns:
+ return dataset
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
- ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
if len(ignored_columns) > 0:
- dset_description = "" if description is None else f"in the {description} set "
+ dset_description = "" if description is None else f"in the {description} set"
logger.info(
f"The following columns {dset_description} don't have a corresponding argument in "
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
- f" you can safely ignore this message."
+ " you can safely ignore this message."
)
- columns = [k for k in self._signature_columns if k in dataset.column_names]
+ columns = [k for k in signature_columns if k in dataset.column_names]
if version.parse(datasets.__version__) < version.parse("1.4.0"):
dataset.set_format(
@@ -587,6 +675,24 @@ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optio
else:
return dataset.remove_columns(ignored_columns)
+ def _get_collator_with_removed_columns(
+ self, data_collator: Callable, description: Optional[str] = None
+ ) -> Callable:
+ """Wrap the data collator in a callable removing unused columns."""
+ if not self.args.remove_unused_columns:
+ return data_collator
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ remove_columns_collator = RemoveColumnsCollator(
+ data_collator=data_collator,
+ signature_columns=signature_columns,
+ logger=logger,
+ description=description,
+ model_name=self.model.__class__.__name__,
+ )
+ return remove_columns_collator
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
@@ -673,14 +779,17 @@ def get_train_dataloader(self) -> DataLoader:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
+ data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
if isinstance(train_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
train_dataset,
- batch_size=self.args.train_batch_size,
+ batch_size=self._train_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.world_size,
process_index=self.args.process_index,
@@ -689,7 +798,7 @@ def get_train_dataloader(self) -> DataLoader:
return DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
@@ -698,12 +807,13 @@ def get_train_dataloader(self) -> DataLoader:
return DataLoader(
train_dataset,
- batch_size=self.args.train_batch_size,
+ batch_size=self._train_batch_size,
sampler=train_sampler,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
+ worker_init_fn=seed_worker,
)
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
@@ -749,9 +859,12 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+ data_collator = self.data_collator
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
@@ -765,7 +878,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
return DataLoader(
eval_dataset,
batch_size=self.args.eval_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
@@ -776,7 +889,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
eval_dataset,
sampler=eval_sampler,
batch_size=self.args.eval_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
@@ -793,8 +906,12 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
The test dataset to use. If it is an `datasets.Dataset`, columns not accepted by the `model.forward()`
method are automatically removed. It must implement `__len__`.
"""
+ data_collator = self.data_collator
+
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
if isinstance(test_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1:
@@ -808,7 +925,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
return DataLoader(
test_dataset,
batch_size=self.args.eval_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
@@ -820,7 +937,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
test_dataset,
sampler=test_sampler,
batch_size=self.args.eval_batch_size,
- collate_fn=self.data_collator,
+ collate_fn=data_collator,
drop_last=self.args.dataloader_drop_last,
pin_memory=self.args.dataloader_pin_memory,
)
@@ -834,7 +951,10 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
`create_scheduler`) in a subclass.
"""
self.create_optimizer()
- self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+ self.create_scheduler(
+ num_training_steps=num_training_steps,
+ optimizer=self.optimizer.optimizer if is_sagemaker_mp_enabled() and smp.state.cfg.fp16 else self.optimizer,
+ )
def create_optimizer(self):
"""
@@ -936,6 +1056,10 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
+ elif args.optim == OptimizerNames.SGD:
+ optimizer_cls = torch.optim.SGD
+ elif args.optim == OptimizerNames.ADAGRAD:
+ optimizer_cls = torch.optim.Adagrad
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
@@ -986,7 +1110,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
for key, value in params.items():
if not hasattr(self.args, key):
logger.warning(
- f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
+ f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
+ " `TrainingArguments`."
)
continue
old_attr = getattr(self.args, key, None)
@@ -1054,7 +1179,56 @@ def call_model_init(self, trial=None):
return model
- def _wrap_model(self, model, training=True):
+ def torch_jit_model_eval(self, model, dataloader, training=False):
+ if not training:
+ if dataloader is None:
+ logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
+ return model
+ jit_inputs = []
+ example_batch = next(iter(dataloader))
+ for key in example_batch:
+ example_tensor = torch.ones_like(example_batch[key])
+ jit_inputs.append(example_tensor)
+ jit_inputs = tuple(jit_inputs)
+ try:
+ jit_model = model.eval()
+ with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]):
+ jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
+ jit_model = torch.jit.freeze(jit_model)
+ jit_model(**example_batch)
+ model = jit_model
+ except (RuntimeError, TypeError) as e:
+ logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
+
+ return model
+
+ def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
+ if not is_ipex_available():
+ raise ImportError(
+ "Using IPEX but IPEX is not installed, please refer to"
+ " https://github.com/intel/intel-extension-for-pytorch."
+ )
+
+ import intel_extension_for_pytorch as ipex
+
+ if not training:
+ model.eval()
+ model = ipex.optimize(model, dtype=dtype, level="O1")
+ else:
+ if not model.training:
+ model.train()
+ model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")
+
+ return model
+
+ def _wrap_model(self, model, training=True, dataloader=None):
+ if self.args.use_ipex:
+ dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
+ model = self.ipex_optimize_model(model, training, dtype=dtype)
+
+ if self.args.jit_mode_eval:
+ model = self.torch_jit_model_eval(model, dataloader, training)
+
if is_sagemaker_mp_enabled():
# Wrapping the base model twice in a DistributedModel will raise an error.
if isinstance(self.model_wrapped, smp.model.DistributedModel):
@@ -1101,6 +1275,33 @@ def _wrap_model(self, model, training=True):
cpu_offload=cpu_offload,
).to(self.args.device)
+ # Distributed training using PyTorch FSDP
+ if self.fsdp is not None:
+ # PyTorch FSDP!
+ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp.wrap import default_auto_wrap_policy
+
+ if FSDPOption.OFFLOAD in self.args.fsdp:
+ cpu_offload = CPUOffload(offload_params=True)
+ else:
+ cpu_offload = CPUOffload(offload_params=False)
+
+ auto_wrap_policy = None
+ if FSDPOption.AUTO_WRAP in self.args.fsdp:
+ if self.args.fsdp_min_num_params > 0:
+ auto_wrap_policy = functools.partial(
+ default_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
+ )
+
+ if type(model) != FSDP:
+ # XXX: Breaking the self.model convention but I see no way around it for now.
+ self.model = model = FSDP(
+ model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy
+ )
+ if FSDPOption.OFFLOAD not in self.args.fsdp:
+ model.to(self.args.device)
+
elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
@@ -1150,7 +1351,8 @@ def train(
kwargs:
Additional keyword arguments used to hide deprecated arguments
"""
- resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
+ if resume_from_checkpoint is False:
+ resume_from_checkpoint = None
# memory metrics - must set up as early as possible
self._memory_tracker.start()
@@ -1180,7 +1382,7 @@ def train(
model_reloaded = False
if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init.
- set_seed(args.seed)
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
self.model = self.call_model_init(trial)
model_reloaded = True
# Reinitializes optimizer and scheduler
@@ -1192,33 +1394,8 @@ def train(
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
- if resume_from_checkpoint is not None:
- if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
- raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
- logger.info(f"Loading model from {resume_from_checkpoint}).")
-
- if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
- config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
- checkpoint_version = config.transformers_version
- if checkpoint_version is not None and checkpoint_version != __version__:
- logger.warning(
- f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
- f"Transformers but your current version is {__version__}. This is not recommended and could "
- "yield to errors or unwanted behaviors."
- )
-
- if args.deepspeed:
- # will be resumed in deepspeed_init
- pass
- else:
- # We load the model state dict on the CPU to avoid an OOM error.
- state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
- # If the model is on the GPU, it still works!
- self._load_state_dict_in_model(state_dict)
-
- # release memory
- del state_dict
+ if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
+ self._load_from_checkpoint(resume_from_checkpoint)
# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
@@ -1226,6 +1403,20 @@ def train(
self._move_model_to_device(self.model, args.device)
self.model_wrapped = self.model
+ inner_training_loop = find_executable_batch_size(
+ self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+ )
+ return inner_training_loop(
+ args=args,
+ resume_from_checkpoint=resume_from_checkpoint,
+ trial=trial,
+ ignore_keys_for_eval=ignore_keys_for_eval,
+ )
+
+ def _inner_training_loop(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ self._train_batch_size = batch_size
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
@@ -1262,7 +1453,8 @@ def train(
num_train_samples = args.max_steps * total_train_batch_size
else:
raise ValueError(
- f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}"
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
)
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
@@ -1270,13 +1462,17 @@ def train(
# nn.DataParallel(model) replicates the model, creating new variables and module
# references registered here no longer work on other gpus, breaking the module
raise ValueError(
- "Currently --debug underflow_overflow is not supported under DP. Please use DDP (torch.distributed.launch)."
+ "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+ " (torch.distributed.launch)."
)
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = (
- self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
+ self.sharded_ddp is not None
+ and self.sharded_ddp != ShardedDDPOption.SIMPLE
+ or is_sagemaker_mp_enabled()
+ or self.fsdp is not None
)
if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
@@ -1299,6 +1495,9 @@ def train(
model = self._wrap_model(self.model_wrapped)
+ if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+ self._load_from_checkpoint(resume_from_checkpoint, model)
+
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
@@ -1419,6 +1618,9 @@ def train(
)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+ if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+
step = -1
for step, inputs in enumerate(epoch_iterator):
@@ -1481,7 +1683,9 @@ def train(
# AMP: gradients need unscaling
self.scaler.unscale_(self.optimizer)
- if hasattr(self.optimizer, "clip_grad_norm"):
+ if is_sagemaker_mp_enabled() and args.fp16:
+ self.optimizer.clip_master_grads(args.max_grad_norm)
+ elif hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
@@ -1529,7 +1733,7 @@ def train(
break
if step < 0:
logger.warning(
- f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
+ "There seems to be not a single sample in your epoch_iterator, stopping training at step"
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
f" num_steps ({max_steps}) higher than the number of available samples."
)
@@ -1561,34 +1765,10 @@ def train(
xm.rendezvous("load_best_model_at_end")
elif args.local_rank != -1:
dist.barrier()
+ elif is_sagemaker_mp_enabled():
+ smp.barrier()
- logger.info(
- f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
- )
-
- best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
- if os.path.exists(best_model_path):
- if self.deepspeed:
- # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
- deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
- self.model = deepspeed_engine.module
- self.model_wrapped = deepspeed_engine
- self.deepspeed = deepspeed_engine
- self.optimizer = optimizer
- self.lr_scheduler = lr_scheduler
- self.deepspeed.load_checkpoint(
- self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
- )
- else:
- # We load the model state dict on the CPU to avoid an OOM error.
- state_dict = torch.load(best_model_path, map_location="cpu")
- # If the model is on the GPU, it still works!
- self._load_state_dict_in_model(state_dict)
- else:
- logger.warning(
- f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
- "on multiple nodes, you should activate `--save_on_each_node`."
- )
+ self._load_best_model()
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
@@ -1609,8 +1789,89 @@ def train(
return TrainOutput(self.state.global_step, train_loss, metrics)
- def _load_state_dict_in_model(self, state_dict):
- load_result = self.model.load_state_dict(state_dict, strict=False)
+ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+
+ if model is None:
+ model = self.model
+ strict_load = is_sagemaker_mp_enabled()
+
+ if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
+ os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+ ):
+ raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+ logger.info(f"Loading model from {resume_from_checkpoint}.")
+
+ if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
+ config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+ checkpoint_version = config.transformers_version
+ if checkpoint_version is not None and checkpoint_version != __version__:
+ logger.warning(
+ f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+ f"Transformers but your current version is {__version__}. This is not recommended and could "
+ "yield to errors or unwanted behaviors."
+ )
+
+ if self.args.deepspeed:
+ # will be resumed in deepspeed_init
+ pass
+ elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+ # If the model is on the GPU, it still works!
+ load_result = model.load_state_dict(state_dict, strict=strict_load)
+ if not strict_load:
+ self._issue_warnings_after_load(load_result)
+ # release memory
+ del state_dict
+ else:
+ # We load the sharded checkpoint
+ load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
+ if not strict_load:
+ self._issue_warnings_after_load(load_result)
+
+ def _load_best_model(self):
+ logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+ best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+ strict_load = is_sagemaker_mp_enabled()
+ model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+ if os.path.exists(best_model_path):
+ if self.deepspeed:
+
+ if self.model_wrapped is not None:
+ # this removes the pre-hooks from the previous engine
+ self.model_wrapped.destroy()
+ self.model_wrapped = None
+
+ # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
+ deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
+ self,
+ num_training_steps=self.args.max_steps,
+ resume_from_checkpoint=self.state.best_model_checkpoint,
+ )
+ self.model = deepspeed_engine.module
+ self.model_wrapped = deepspeed_engine
+ self.deepspeed = deepspeed_engine
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(best_model_path, map_location="cpu")
+ # If the model is on the GPU, it still works!
+ load_result = model.load_state_dict(state_dict, strict=strict_load)
+ if not strict_load:
+ self._issue_warnings_after_load(load_result)
+ elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
+ load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
+ if not strict_load:
+ self._issue_warnings_after_load(load_result)
+ else:
+ logger.warning(
+ f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+ "on multiple nodes, you should activate `--save_on_each_node`."
+ )
+
+ def _issue_warnings_after_load(self, load_result):
if len(load_result.missing_keys) != 0:
if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
@@ -1663,7 +1924,7 @@ def _load_rng_state(self, checkpoint):
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank != -1:
rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
- if not os.path.isfile(os.path.join(checkpoint, rng_file)):
+ if not os.path.isfile(rng_file):
logger.info(
f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
@@ -1741,17 +2002,21 @@ def _save_checkpoint(self, model, trial, metrics=None):
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
- if smp.rdp_rank() == 0:
- # Consolidate the state dict on all processed of rdp_rank 0
- opt_state_dict = self.optimizer.state_dict()
- # Save it and the scheduler on the main process
- if self.args.should_save:
- torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
- with warnings.catch_warnings(record=True) as caught_warnings:
- torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
- reissue_pt_warnings(caught_warnings)
- if self.do_grad_scaling:
- torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+ opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+ smp.barrier()
+ if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+ smp.save(
+ opt_state_dict,
+ os.path.join(output_dir, OPTIMIZER_NAME),
+ partial=True,
+ v3=smp.state.cfg.shard_optimizer_state,
+ )
+ if self.args.should_save:
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ if self.do_grad_scaling:
+ torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -1800,6 +2065,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
+
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank == -1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
@@ -1822,9 +2088,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
return
- if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
- os.path.join(checkpoint, SCHEDULER_NAME)
- ):
+ checkpoint_file_exists = (
+ glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+ if is_sagemaker_mp_enabled()
+ else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+ )
+ if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
@@ -1840,9 +2109,18 @@ def _load_optimizer_and_scheduler(self, checkpoint):
self.lr_scheduler.load_state_dict(lr_scheduler_state)
else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
- self.optimizer.load_state_dict(
- torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
- )
+ if is_sagemaker_mp_enabled():
+
+ def opt_load_hook(mod, opt):
+ opt.load_state_dict(
+ smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True), gather_if_shard=False
+ )
+
+ self.model_wrapped.register_post_step_hook(opt_load_hook)
+ else:
+ self.optimizer.load_state_dict(
+ torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+ )
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
@@ -1993,16 +2271,46 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
return inputs
+ def compute_loss_context_manager(self):
+ """
+ A helper wrapper to group together context managers.
+ """
+ return ContextManagers(
+ [
+ self.torchdynamo_smart_context_manager(),
+ self.autocast_smart_context_manager(),
+ ]
+ )
+
+ def torchdynamo_smart_context_manager(self):
+ """
+ A helper wrapper that creates an appropriate context manager for `torchdynamo`.
+ """
+ ctx_manager = contextlib.nullcontext()
+ if is_torchdynamo_available():
+ import torchdynamo
+ from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
+
+ if self.args.torchdynamo == "eager":
+ ctx_manager = torchdynamo.optimize("eager")
+ elif self.args.torchdynamo == "nvfuser":
+ ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
+ return ctx_manager
+
def autocast_smart_context_manager(self):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
"""
- if self.use_amp:
+ if self.use_cuda_amp or self.use_cpu_amp:
if version.parse(torch.__version__) >= version.parse("1.10"):
- ctx_manager = autocast(dtype=self.amp_dtype)
+ ctx_manager = (
+ torch.cpu.amp.autocast(dtype=self.amp_dtype)
+ if self.use_cpu_amp
+ else torch.cuda.amp.autocast(dtype=self.amp_dtype)
+ )
else:
- ctx_manager = autocast()
+ ctx_manager = torch.cuda.amp.autocast()
else:
ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
@@ -2030,11 +2338,10 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
- scaler = self.scaler if self.do_grad_scaling else None
- loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
+ loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
- with self.autocast_smart_context_manager():
+ with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
@@ -2118,7 +2425,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif (
- ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
+ ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
+ or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
+ or self.fsdp is not None
):
state_dict = self.model.state_dict()
@@ -2146,8 +2455,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
# This must be called on all ranks
if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
logger.warning(
- "deepspeed.save_16bit_model didn't save the model, since stage3_gather_16bit_weights_on_model_save=false. "
- "Saving the full checkpoint instead, use zero_to_fp32.py to recover weights"
+ "deepspeed.save_16bit_model didn't save the model, since"
+ " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+ " zero_to_fp32.py to recover weights"
)
self.deepspeed.save_checkpoint(output_dir)
@@ -2428,7 +2738,7 @@ def evaluation_loop(
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
- model = self._wrap_model(self.model, training=False)
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
@@ -2725,7 +3035,7 @@ def prediction_step(
logits = smp_nested_concat(logits_mb)
else:
if has_labels:
- with self.autocast_smart_context_manager():
+ with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
@@ -2735,7 +3045,7 @@ def prediction_step(
logits = outputs[1:]
else:
loss = None
- with self.autocast_smart_context_manager():
+ with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
@@ -2989,7 +3299,7 @@ def prediction_loop(
deepspeed_engine.optimizer.optimizer = None
deepspeed_engine.lr_scheduler = None
- model = self._wrap_model(self.model, training=False)
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 92abe1ed50634c..06875b74e1dae2 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -556,7 +556,8 @@ def on_evaluate(self, args, state, control, metrics, **kwargs):
if metric_value is None:
logger.warning(
- f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled"
+ f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping"
+ " is disabled"
)
return
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index d76552c3755eee..f669e6f32ae4cd 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -22,6 +22,7 @@
import os
import sys
import warnings
+from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from logging import StreamHandler
@@ -54,8 +55,22 @@
logger = logging.get_logger(__name__)
+def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
+ if isinstance(tensor_or_array, torch.Tensor):
+ if hasattr(torch, "atleast_1d"):
+ tensor_or_array = torch.atleast_1d(tensor_or_array)
+ elif tensor_or_array.ndim < 1:
+ tensor_or_array = tensor_or_array[None]
+ else:
+ tensor_or_array = np.atleast_1d(tensor_or_array)
+ return tensor_or_array
+
+
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
+ tensor1 = atleast_1d(tensor1)
+ tensor2 = atleast_1d(tensor2)
+
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
return torch.cat((tensor1, tensor2), dim=0)
@@ -71,6 +86,9 @@ def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
+ array1 = atleast_1d(array1)
+ array2 = atleast_1d(array2)
+
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
return np.concatenate((array1, array2), axis=0)
@@ -111,7 +129,7 @@ def find_batch_size(tensors):
result = find_batch_size(t)
if result is not None:
return result
- elif isinstance(tensors, (dict, BatchEncoding)):
+ elif isinstance(tensors, Mapping):
for key, value in tensors.items():
result = find_batch_size(value)
if result is not None:
@@ -148,8 +166,7 @@ def nested_xla_mesh_reduce(tensors, name):
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
- if tensors.ndim == 0:
- tensors = tensors[None]
+ tensors = atleast_1d(tensors)
return xm.mesh_reduce(name, tensors, torch.cat)
else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
@@ -159,8 +176,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
- if len(tensor.shape) <= 0:
- tensor = tensor[None]
+ tensor = atleast_1d(tensor)
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
@@ -1004,15 +1020,10 @@ def get_parameter_names(model, forbidden_layer_types):
import smdistributed.modelparallel.torch as smp
@smp.step()
- def smp_forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
- with torch.cuda.amp.autocast(enabled=(scaler is not None)):
- outputs = model(**inputs)
-
+ def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
+ outputs = model(**inputs)
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
loss /= gradient_accumulation_steps
- if scaler is not None:
- loss = scaler.scale(loss).squeeze()
-
model.backward(loss)
return loss
@@ -1030,7 +1041,7 @@ def smp_gather(tensor):
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
)
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
- all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
+ all_tensors = [atleast_1d(t) for t in all_tensors]
return torch.cat([t.cpu() for t in all_tensors], dim=0)
def smp_nested_concat(tensor):
diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 5513b58bef94b9..7a290fe149de84 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -183,7 +183,7 @@ def prediction_step(
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
with torch.no_grad():
- with self.autocast_smart_context_manager():
+ with self.compute_loss_context_manager():
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py
index 71c2e691d2a7ab..737dd4deaf6887 100644
--- a/src/transformers/trainer_tf.py
+++ b/src/transformers/trainer_tf.py
@@ -34,7 +34,14 @@
from .modeling_tf_utils import TFPreTrainedModel
from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
+from .trainer_utils import (
+ PREFIX_CHECKPOINT_DIR,
+ EvalPrediction,
+ IntervalStrategy,
+ PredictionOutput,
+ enable_full_determinism,
+ set_seed,
+)
from .training_args_tf import TFTrainingArguments
from .utils import logging
@@ -134,7 +141,7 @@ def __init__(
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
- set_seed(self.args.seed)
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
def get_train_tfdataset(self) -> tf.data.Dataset:
"""
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 4450bfde646eff..46fd0cdd05b6f4 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -25,7 +25,7 @@
import re
import threading
import time
-from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
@@ -36,6 +36,7 @@
is_torch_available,
is_torch_cuda_available,
is_torch_tpu_available,
+ requires_backends,
)
@@ -46,6 +47,39 @@
import tensorflow as tf
+def seed_worker(_):
+ """
+ Helper function to set worker seed during Dataloader initialization.
+ """
+ worker_seed = torch.initial_seed() % 2**32
+ set_seed(worker_seed)
+
+
+def enable_full_determinism(seed: int):
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
+ """
+ # set seed first
+ set_seed(seed)
+
+ if is_torch_available():
+ # Ā Enable PyTorch deterministic mode. This potentially requires either the environment
+ # Ā variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ if is_tf_available():
+ tf.config.experimental.enable_op_determinism()
+
+
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
@@ -211,7 +245,7 @@ def default_hp_space_optuna(trial) -> Dict[str, float]:
def default_hp_space_ray(trial) -> Dict[str, float]:
from .integrations import is_ray_tune_available
- assert is_ray_tune_available(), "This function needs ray installed: `pip " "install ray[tune]`"
+ assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`"
from ray import tune
return {
@@ -355,6 +389,7 @@ class TrainerMemoryTracker:
stages = {
"__init__": "init",
"train": "train",
+ "_inner_training_loop": "train",
"evaluate": "eval",
"predict": "test",
}
@@ -582,3 +617,80 @@ class ShardedDDPOption(ExplicitEnum):
ZERO_DP_3 = "zero_dp_3"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"
+
+
+def find_executable_batch_size(
+ function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
+):
+ """
+ Args:
+ A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
+ CUDNN, the batch size is cut in half and passed to `function` `function` must take in a `batch_size` parameter as
+ its first argument.
+ function (`callable`, *optional*)
+ A function to wrap
+ starting_batch_size (`int`, *optional*)
+ The batch size to try and fit into memory
+ auto_find_batch_size (`bool`, *optional*)
+ If False, will just execute `function`
+ """
+ if function is None:
+ return functools.partial(
+ find_executable_batch_size,
+ starting_batch_size=starting_batch_size,
+ auto_find_batch_size=auto_find_batch_size,
+ )
+
+ if auto_find_batch_size:
+ requires_backends(find_executable_batch_size, "accelerate")
+ import accelerate.memory_utils as mem_utils
+
+ return mem_utils.find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
+
+ return functools.partial(function, batch_size=starting_batch_size)
+
+
+class FSDPOption(ExplicitEnum):
+ FULL_SHARD = "full_shard"
+ SHARD_GRAD_OP = "shard_grad_op"
+ OFFLOAD = "offload"
+ AUTO_WRAP = "auto_wrap"
+
+
+class RemoveColumnsCollator:
+ """Wrap the data collator to remove unused columns before they are passed to the collator."""
+
+ def __init__(
+ self,
+ data_collator,
+ signature_columns,
+ logger=None,
+ model_name: Optional[str] = None,
+ description: Optional[str] = None,
+ ):
+ self.data_collator = data_collator
+ self.signature_columns = signature_columns
+ self.logger = logger
+ self.description = description
+ self.model_name = model_name
+ self.message_logged = False
+
+ def _remove_columns(self, feature: dict) -> dict:
+ if not isinstance(feature, dict):
+ return feature
+ if not self.message_logged and self.logger and self.model_name:
+ ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
+ if len(ignored_columns) > 0:
+ dset_description = "" if self.description is None else f"in the {self.description} set"
+ self.logger.info(
+ f"The following columns {dset_description} don't have a corresponding argument in "
+ f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
+ f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
+ " you can safely ignore this message."
+ )
+ self.message_logged = True
+ return {k: v for k, v in feature.items() if k in self.signature_columns}
+
+ def __call__(self, features: List[dict]):
+ features = [self._remove_columns(feature) for feature in features]
+ return self.data_collator(features)
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index cc0a5ec835704d..7d7f68dfc8bc50 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -23,7 +23,14 @@
from typing import Any, Dict, List, Optional
from .debug_utils import DebugOption
-from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
+from .trainer_utils import (
+ EvaluationStrategy,
+ FSDPOption,
+ HubStrategy,
+ IntervalStrategy,
+ SchedulerType,
+ ShardedDDPOption,
+)
from .utils import (
ExplicitEnum,
cached_property,
@@ -69,6 +76,15 @@ def default_logdir() -> str:
return os.path.join("runs", current_time + "_" + socket.gethostname())
+def get_int_from_env(env_keys, default):
+ """Returns the first positive env value found in the `env_keys` list or the default."""
+ for e in env_keys:
+ val = int(os.environ.get(e, -1))
+ if val >= 0:
+ return val
+ return default
+
+
class OptimizerNames(ExplicitEnum):
"""
Stores the acceptable string identifiers for optimizers.
@@ -80,6 +96,8 @@ class OptimizerNames(ExplicitEnum):
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_BNB = "adamw_bnb_8bit"
+ SGD = "sgd"
+ ADAGRAD = "adagrad"
@dataclass
@@ -227,9 +245,14 @@ class TrainingArguments:
Random seed to be used with data samplers. If not set, random generators for data sampling will use the
same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
seed.
+ jit_mode_eval (`bool`, *optional*, defaults to `False`):
+ Whether or not to use PyTorch jit trace for inference.
+ use_ipex (`bool`, *optional*, defaults to `False`):
+ Use Intel extension for PyTorch when it is available. [IPEX
+ installation](https://github.com/intel/intel-extension-for-pytorch).
bf16 (`bool`, *optional*, defaults to `False`):
Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
- NVIDIA architecture. This is an experimental API and it may change.
+ NVIDIA architecture or using CPU (no_cuda). This is an experimental API and it may change.
fp16 (`bool`, *optional*, defaults to `False`):
Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
fp16_opt_level (`str`, *optional*, defaults to 'O1'):
@@ -238,9 +261,9 @@ class TrainingArguments:
fp16_backend (`str`, *optional*, defaults to `"auto"`):
This argument is deprecated. Use `half_precision_backend` instead.
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
- The backend to use for mixed precision training. Must be one of `"auto"`, `"amp"` or `"apex"`. `"auto"`
- will use AMP or APEX depending on the PyTorch version detected, while the other choices will force the
- requested backend.
+ The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
+ `"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
+ will force the requested backend.
bf16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.
@@ -280,8 +303,7 @@ class TrainingArguments:
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
set to warn or lower (default), `False` otherwise.
remove_unused_columns (`bool`, *optional*, defaults to `True`):
- If using `datasets.Dataset` datasets, whether or not to automatically remove the columns unused by the
- model forward method.
+ Whether or not to automatically remove the columns unused by the model forward method.
(Note that this behavior is not implemented for [`TFTrainer`] yet.)
label_names (`List[str]`, *optional*):
@@ -294,8 +316,8 @@ class TrainingArguments:
- When set to `True`, the parameters `save_strategy` needs to be the same as `eval_strategy`, and in the case
- it is "steps", `save_steps` must be a round multiple of `eval_steps`.
+ When set to `True`, the parameters `save_strategy` needs to be the same as `evaluation_strategy`, and in
+ the case it is "steps", `save_steps` must be a round multiple of `eval_steps`.
@@ -331,6 +353,18 @@ class TrainingArguments:
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
list for `False` and `["simple"]` for `True`.
+ fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`):
+ Use PyTorch Distributed Parallel Training (in distributed training only).
+
+ A list of options along the following:
+
+ - `"full_shard"`: Shard parameters, gradients and optimizer states.
+ - `"shard_grad_op"`: Shard optimizer states and gradients.
+ - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
+ `"shard_grad_op"`).
+ - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
+ fsdp_min_num_params (`int`, *optional*, defaults to `0`):
+ FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed).
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -424,6 +458,21 @@ class TrainingArguments:
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class.
+ auto_find_batch_size (`bool`, *optional*, defaults to `False`)
+ Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
+ CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
+ full_determinism (`bool`, *optional*, defaults to `False`)
+ If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
+ distributed training
+ torchdynamo (`str`, *optional*):
+ The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
+ "nvfuser]. This is an experimental API and subject to change.
+ ray_scope (`str`, *optional*, defaults to `"last"`):
+ The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
+ then use the last checkpoint of all trials, compare those, and select the best one. However, other options
+ are also available. See the [Ray documentation](
+ https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
+ more options.
"""
output_dir: str = field(
@@ -461,15 +510,19 @@ class TrainingArguments:
per_gpu_train_batch_size: Optional[int] = field(
default=None,
metadata={
- "help": "Deprecated, the use of `--per_device_train_batch_size` is preferred. "
- "Batch size per GPU/TPU core/CPU for training."
+ "help": (
+ "Deprecated, the use of `--per_device_train_batch_size` is preferred. "
+ "Batch size per GPU/TPU core/CPU for training."
+ )
},
)
per_gpu_eval_batch_size: Optional[int] = field(
default=None,
metadata={
- "help": "Deprecated, the use of `--per_device_eval_batch_size` is preferred. "
- "Batch size per GPU/TPU core/CPU for evaluation."
+ "help": (
+ "Deprecated, the use of `--per_device_eval_batch_size` is preferred. "
+ "Batch size per GPU/TPU core/CPU for evaluation."
+ )
},
)
@@ -485,7 +538,10 @@ class TrainingArguments:
eval_delay: Optional[float] = field(
default=0,
metadata={
- "help": "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the evaluation_strategy."
+ "help": (
+ "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the"
+ " evaluation_strategy."
+ )
},
)
@@ -513,7 +569,11 @@ class TrainingArguments:
log_level: Optional[str] = field(
default="passive",
metadata={
- "help": "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug', 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the application set the level. Defaults to 'passive'.",
+ "help": (
+ "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug',"
+ " 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and"
+ " lets the application set the level. Defaults to 'passive'."
+ ),
"choices": trainer_log_levels.keys(),
},
)
@@ -527,7 +587,10 @@ class TrainingArguments:
log_on_each_node: bool = field(
default=True,
metadata={
- "help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
+ "help": (
+ "When doing a multinode distributed training, whether to log once per node or just once on the main"
+ " node."
+ )
},
)
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
@@ -555,16 +618,34 @@ class TrainingArguments:
save_on_each_node: bool = field(
default=False,
metadata={
- "help": "When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one"
+ "help": (
+ "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
+ " only on the main one"
+ )
},
)
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
- data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+ data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+ jit_mode_eval: bool = field(
+ default=False, metadata={"help": "Whether or not to use PyTorch jit trace for inference"}
+ )
+ use_ipex: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Use Intel extension for PyTorch when it is available, installation:"
+ " 'https://github.com/intel/intel-extension-for-pytorch'"
+ )
+ },
+ )
bf16: bool = field(
default=False,
metadata={
- "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA architecture. This is an experimental API and it may change."
+ "help": (
+ "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
+ " architecture or using CPU (no_cuda). This is an experimental API and it may change."
+ )
},
)
fp16: bool = field(
@@ -582,26 +663,35 @@ class TrainingArguments:
)
half_precision_backend: str = field(
default="auto",
- metadata={"help": "The backend to be used for half precision.", "choices": ["auto", "amp", "apex"]},
+ metadata={
+ "help": "The backend to be used for half precision.",
+ "choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
+ },
)
bf16_full_eval: bool = field(
default=False,
metadata={
- "help": "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may change."
+ "help": (
+ "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may"
+ " change."
+ )
},
)
fp16_full_eval: bool = field(
default=False,
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
)
- tf32: bool = field(
+ tf32: Optional[bool] = field(
default=None,
metadata={
- "help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
+ "help": (
+ "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental"
+ " API and it may change."
+ )
},
)
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
- xpu_backend: str = field(
+ xpu_backend: Optional[str] = field(
default=None,
metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
)
@@ -611,26 +701,33 @@ class TrainingArguments:
tpu_metrics_debug: bool = field(
default=False,
metadata={
- "help": "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
+ "help": (
+ "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
+ )
},
)
debug: str = field(
default="",
metadata={
- "help": "Whether or not to enable debug mode. Current options: "
- "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
- "`tpu_metrics_debug` (print debug metrics on TPU)."
+ "help": (
+ "Whether or not to enable debug mode. Current options: "
+ "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
+ "`tpu_metrics_debug` (print debug metrics on TPU)."
+ )
},
)
dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
)
- eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
+ eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
dataloader_num_workers: int = field(
default=0,
metadata={
- "help": "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process."
+ "help": (
+ "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded"
+ " in the main process."
+ )
},
)
@@ -666,22 +763,51 @@ class TrainingArguments:
ignore_data_skip: bool = field(
default=False,
metadata={
- "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
+ "help": (
+ "When resuming training, whether or not to skip the first epochs and batches to get to the same"
+ " training data."
+ )
},
)
sharded_ddp: str = field(
default="",
metadata={
- "help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
- "should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
- "like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
- "with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
+ "help": (
+ "Whether or not to use sharded DDP training (in distributed training only). The base option should be"
+ " `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` like"
+ " this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3`"
+ " with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`."
+ ),
+ },
+ )
+ fsdp: str = field(
+ default="",
+ metadata={
+ "help": (
+ "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
+ " only). The base option should be `full_shard` or `shard_grad_op` and you can add CPU-offload to"
+ " `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op offload`. You can"
+ " add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard auto_wrap` or"
+ " `shard_grad_op auto_wrap`."
+ ),
+ },
+ )
+ fsdp_min_num_params: int = field(
+ default=0,
+ metadata={
+ "help": (
+ "FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is"
+ " passed)."
+ )
},
)
deepspeed: Optional[str] = field(
default=None,
metadata={
- "help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already loaded json file as a dict"
+ "help": (
+ "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already"
+ " loaded json file as a dict"
+ )
},
)
label_smoothing_factor: float = field(
@@ -706,15 +832,19 @@ class TrainingArguments:
ddp_find_unused_parameters: Optional[bool] = field(
default=None,
metadata={
- "help": "When using distributed training, the value of the flag `find_unused_parameters` passed to "
- "`DistributedDataParallel`."
+ "help": (
+ "When using distributed training, the value of the flag `find_unused_parameters` passed to "
+ "`DistributedDataParallel`."
+ )
},
)
ddp_bucket_cap_mb: Optional[int] = field(
default=None,
metadata={
- "help": "When using distributed training, the value of the flag `bucket_cap_mb` passed to "
- "`DistributedDataParallel`."
+ "help": (
+ "When using distributed training, the value of the flag `bucket_cap_mb` passed to "
+ "`DistributedDataParallel`."
+ )
},
)
dataloader_pin_memory: bool = field(
@@ -733,14 +863,14 @@ class TrainingArguments:
default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
)
- hub_model_id: str = field(
+ hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_strategy: HubStrategy = field(
default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)
- hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+ hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
gradient_checkpointing: bool = field(
default=False,
@@ -754,21 +884,72 @@ class TrainingArguments:
# Deprecated arguments
fp16_backend: str = field(
default="auto",
- metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
+ metadata={
+ "help": "Deprecated. Use half_precision_backend instead",
+ "choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
+ },
)
- push_to_hub_model_id: str = field(
+ push_to_hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
)
- push_to_hub_organization: str = field(
+ push_to_hub_organization: Optional[str] = field(
default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
)
- push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+ push_to_hub_token: Optional[str] = field(
+ default=None, metadata={"help": "The token to use to push to the Model Hub."}
+ )
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",
metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
)
+ auto_find_batch_size: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to automatically decrease the batch size in half and rerun the training loop again each time"
+ " a CUDA Out-of-Memory was reached"
+ )
+ },
+ )
+ full_determinism: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed"
+ " training"
+ )
+ },
+ )
+ torchdynamo: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
+ " make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
+ " before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
+ " and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
+ " are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
+ " nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
+ ),
+ "choices": ["eager", "nvfuser"],
+ },
+ )
+ ray_scope: Optional[str] = field(
+ default="last",
+ metadata={
+ "help": (
+ 'The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray'
+ " will then use the last checkpoint of all trials, compare those, and select the best one. However,"
+ " other options are also available. See the Ray documentation"
+ " (https://docs.ray.io/en/latest/tune/api_docs/analysis.html"
+ "#ray.tune.ExperimentAnalysis.get_best_trial)"
+ " for more options."
+ )
+ },
+ )
+
def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
# This needs to happen before any call to self.device or self.n_gpu.
@@ -795,7 +976,8 @@ def __post_init__(self):
if isinstance(self.evaluation_strategy, EvaluationStrategy):
warnings.warn(
- "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of š¤ Transformers. Use `IntervalStrategy` instead",
+ "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5"
+ " of š¤ Transformers. Use `IntervalStrategy` instead",
FutureWarning,
)
# Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
@@ -817,7 +999,8 @@ def __post_init__(self):
self.eval_steps = self.logging_steps
else:
raise ValueError(
- f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or --logging_steps"
+ f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or"
+ " --logging_steps"
)
# logging_steps must be non-zero for logging_strategy that is other than 'no'
@@ -846,20 +1029,25 @@ def __post_init__(self):
if self.fp16_backend and self.fp16_backend != "auto":
warnings.warn(
- "`fp16_backend` is deprecated and will be removed in version 5 of š¤ Transformers. Use `half_precision_backend` instead",
+ "`fp16_backend` is deprecated and will be removed in version 5 of š¤ Transformers. Use"
+ " `half_precision_backend` instead",
FutureWarning,
)
self.half_precision_backend = self.fp16_backend
- if (self.bf16 or self.bf16_full_eval) and not is_torch_bf16_available():
- raise ValueError("Your setup doesn't support bf16. You need Ampere GPU, torch>=1.10, cuda>=11.0")
+ if (self.bf16 or self.bf16_full_eval) and not is_torch_bf16_available() and not self.no_cuda:
+ raise ValueError(
+ "Your setup doesn't support bf16. You need torch>=1.10, using Ampere GPU with cuda>=11.0 or using CPU"
+ " (no_cuda)"
+ )
if self.fp16 and self.bf16:
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
if self.bf16:
if self.half_precision_backend == "apex":
raise ValueError(
- " `--half_precision_backend apex`: bf16 is not supported by apex. Use `--half_precision_backend amp` instead"
+ " `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
+ " `--half_precision_backend cuda_amp` instead"
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")
@@ -867,7 +1055,8 @@ def __post_init__(self):
self.optim = OptimizerNames(self.optim)
if self.adafactor:
warnings.warn(
- "`--adafactor` is deprecated and will be removed in version 5 of š¤ Transformers. Use `--optim adafactor` instead",
+ "`--adafactor` is deprecated and will be removed in version 5 of š¤ Transformers. Use `--optim"
+ " adafactor` instead",
FutureWarning,
)
self.optim = OptimizerNames.ADAFACTOR
@@ -876,10 +1065,23 @@ def __post_init__(self):
is_torch_available()
and (self.device.type != "cuda")
and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
- and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval)
+ and (self.fp16 or self.fp16_full_eval)
):
raise ValueError(
- "Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
+ "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
+ " (`--fp16_full_eval`) can only be used on CUDA devices."
+ )
+
+ if (
+ is_torch_available()
+ and (self.device.type != "cuda")
+ and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
+ and (self.device.type != "cpu")
+ and (self.bf16 or self.bf16_full_eval)
+ ):
+ raise ValueError(
+ "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
+ " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
)
if is_torch_available() and self.tf32 is not None:
@@ -914,7 +1116,8 @@ def __post_init__(self):
raise ValueError("warmup_ratio must lie in range [0,1]")
elif self.warmup_ratio > 0 and self.warmup_steps > 0:
logger.info(
- "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
+ "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio"
+ " during training"
)
if isinstance(self.sharded_ddp, bool):
@@ -931,9 +1134,25 @@ def __post_init__(self):
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
+ if isinstance(self.fsdp, bool):
+ self.fsdp = "full_shard" if self.fsdp else ""
+ if isinstance(self.fsdp, str):
+ self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
+ if self.fsdp == [FSDPOption.OFFLOAD]:
+ raise ValueError(
+ "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
+ '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
+ )
+ elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.sharded_ddp:
+ raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
+
+ if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
+ warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
+
if self.tpu_metrics_debug:
warnings.warn(
- "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of š¤ Transformers. Use `--debug tpu_metrics_debug` instead",
+ "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of š¤ Transformers. Use"
+ " `--debug tpu_metrics_debug` instead",
FutureWarning,
)
self.debug += " tpu_metrics_debug"
@@ -1041,6 +1260,10 @@ def _setup_devices(self) -> "torch.device":
if self.no_cuda:
device = torch.device("cpu")
self._n_gpu = 0
+ self.local_rank = get_int_from_env(
+ ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"],
+ self.local_rank,
+ )
if self.local_rank != -1 and not torch.distributed.is_initialized():
# Initializes distributed backend for cpu
if self.xpu_backend not in ("mpi", "ccl"):
@@ -1048,7 +1271,30 @@ def _setup_devices(self) -> "torch.device":
"CPU distributed training backend is not properly set. "
"Please set '--xpu_backend' to either 'mpi' or 'ccl'."
)
- torch.distributed.init_process_group(backend=self.xpu_backend)
+ if self.xpu_backend == "ccl" and int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1:
+ raise ValueError(
+ "CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. "
+ "Please use like 'export CCL_WORKER_COUNT = 1' to set."
+ )
+
+ # Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
+ rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
+ size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1)
+ local_size = get_int_from_env(
+ ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
+ )
+ os.environ["RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(size)
+ os.environ["LOCAL_RANK"] = str(self.local_rank)
+ if not os.environ.get("MASTER_PORT", None):
+ os.environ["MASTER_PORT"] = "29500"
+ if not os.environ.get("MASTER_ADDR", None):
+ if local_size != size or self.xpu_backend != "mpi":
+ raise ValueError(
+ "Looks like distributed multinode run but MASTER_ADDR env not set, "
+ "please try exporting rank 0's hostname as MASTER_ADDR"
+ )
+ torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size)
elif is_torch_tpu_available():
device = xm.xla_device()
self._n_gpu = 0
diff --git a/src/transformers/training_args_seq2seq.py b/src/transformers/training_args_seq2seq.py
index ef3ccdf2601739..026dce81bcfddf 100644
--- a/src/transformers/training_args_seq2seq.py
+++ b/src/transformers/training_args_seq2seq.py
@@ -51,14 +51,18 @@ class Seq2SeqTrainingArguments(TrainingArguments):
generation_max_length: Optional[int] = field(
default=None,
metadata={
- "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
- "to the `max_length` value of the model configuration."
+ "help": (
+ "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+ "to the `max_length` value of the model configuration."
+ )
},
)
generation_num_beams: Optional[int] = field(
default=None,
metadata={
- "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
- "to the `num_beams` value of the model configuration."
+ "help": (
+ "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+ "to the `num_beams` value of the model configuration."
+ )
},
)
diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py
index 4f3c41e2cab2b3..060b78e9220518 100644
--- a/src/transformers/training_args_tf.py
+++ b/src/transformers/training_args_tf.py
@@ -14,7 +14,7 @@
import warnings
from dataclasses import dataclass, field
-from typing import Tuple
+from typing import Optional, Tuple
from .training_args import TrainingArguments
from .utils import cached_property, is_tf_available, logging, tf_required
@@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
Whether to activate the XLA compilation or not.
"""
- tpu_name: str = field(
+ tpu_name: Optional[str] = field(
default=None,
metadata={"help": "Name of TPU"},
)
- tpu_zone: str = field(
+ tpu_zone: Optional[str] = field(
default=None,
metadata={"help": "Zone of TPU"},
)
- gcp_project: str = field(
+ gcp_project: Optional[str] = field(
default=None,
metadata={"help": "Name of Cloud TPU-enabled project"},
)
@@ -195,8 +195,7 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
# Set to float16 at first
if self.fp16:
- policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
- tf.keras.mixed_precision.experimental.set_policy(policy)
+ tf.keras.mixed_precision.set_global_policy("mixed_float16")
if self.no_cuda:
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
@@ -217,8 +216,7 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
if tpu:
# Set to bfloat16 in case of TPU
if self.fp16:
- policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
- tf.keras.mixed_precision.experimental.set_policy(policy)
+ tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 6101a924f969a0..fea13ff47cc856 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -38,6 +38,7 @@
TensorType,
cached_property,
find_labels,
+ flatten_dict,
is_tensor,
to_numpy,
to_py_obj,
@@ -73,6 +74,7 @@
is_local_clone,
is_offline_mode,
is_remote_url,
+ send_example_telemetry,
url_to_filename,
)
from .import_utils import (
@@ -83,7 +85,9 @@
USE_TF,
USE_TORCH,
DummyObject,
+ OptionalDependencyNotAvailable,
_LazyModule,
+ is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_coloredlogs_available,
@@ -93,6 +97,7 @@
is_flax_available,
is_ftfy_available,
is_in_notebook,
+ is_ipex_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
@@ -127,6 +132,7 @@
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
+ is_torchdynamo_available,
is_training_run_on_sagemaker,
is_vision_available,
requires_backends,
@@ -168,8 +174,6 @@ def check_min_version(min_version):
error_message += f" but the version found is {__version__}.\n"
raise ImportError(
error_message
- + (
- "Check out https://huggingface.co/transformers/examples.html for the examples corresponding to other "
- "versions of HuggingFace Transformers."
- )
+ + "Check out https://huggingface.co/transformers/examples.html for the examples corresponding to other "
+ "versions of HuggingFace Transformers."
)
diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py
index 6311437b8e3fe6..44c3b1cf3e4b09 100644
--- a/src/transformers/utils/dummy_flax_objects.py
+++ b/src/transformers/utils/dummy_flax_objects.py
@@ -326,6 +326,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxBertForCausalLM(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxBertForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
@@ -389,6 +396,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxBigBirdForCausalLM(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxBigBirdForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
@@ -578,6 +592,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxElectraForCausalLM(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxElectraForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
@@ -704,6 +725,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxLongT5Model(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxLongT5PreTrainedModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxMarianModel(metaclass=DummyObject):
_backends = ["flax"]
@@ -774,6 +816,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxOPTForCausalLM(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxOPTModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxOPTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxPegasusForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
@@ -795,6 +858,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxRobertaForCausalLM(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxRobertaForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 898848d5ba16e4..847e7d87abee40 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -234,6 +234,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class TypicalLogitsWarper(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class MaxLengthCriteria(metaclass=DummyObject):
_backends = ["torch"]
@@ -402,6 +409,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
MODEL_FOR_VISION_2_SEQ_MAPPING = None
+MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
+
+
MODEL_MAPPING = None
@@ -569,6 +579,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class AutoModelWithLMHead(metaclass=DummyObject):
_backends = ["torch"]
@@ -959,6 +976,44 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class BloomForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomForTokenClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1216,6 +1271,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+CVT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class CvtForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class CvtModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class CvtPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1406,6 +1485,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class DebertaV2ForMultipleChoice(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class DebertaV2ForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
@@ -1780,6 +1866,58 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class FlavaForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaImageCodebook(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaImageModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaMultimodalModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class FlavaTextModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
FNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2052,6 +2190,37 @@ def load_tf_weights_in_gpt_neo(*args, **kwargs):
requires_backends(load_tf_weights_in_gpt_neo, ["torch"])
+GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class GPTNeoXForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GPTNeoXLayer(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GPTNeoXModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class GPTNeoXPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2284,6 +2453,44 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class LayoutLMv3ForQuestionAnswering(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LayoutLMv3ForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LayoutLMv3ForTokenClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LayoutLMv3Model(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LayoutLMv3PreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2322,6 +2529,37 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class LevitForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LevitForImageClassificationWithTeacher(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LevitModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LevitPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2381,6 +2619,37 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class LongT5EncoderModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LongT5ForConditionalGeneration(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LongT5Model(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class LongT5PreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2586,6 +2855,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class MCTCTForCTC(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MCTCTModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MCTCTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2938,6 +3231,30 @@ def load_tf_weights_in_openai_gpt(*args, **kwargs):
requires_backends(load_tf_weights_in_openai_gpt, ["torch"])
+OPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class OPTForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OPTModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class OPTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class PegasusForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
@@ -3785,6 +4102,13 @@ def __init__(self, *args, **kwargs):
SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+class SplinterForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class SplinterForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
@@ -3938,6 +4262,23 @@ def load_tf_weights_in_t5(*args, **kwargs):
requires_backends(load_tf_weights_in_t5, ["torch"])
+TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TrajectoryTransformerModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class TrajectoryTransformerPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -4357,6 +4698,58 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class Wav2Vec2ConformerForAudioFrameClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerForCTC(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerForPreTraining(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerForXVector(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Wav2Vec2ConformerPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -4697,6 +5090,30 @@ def load_tf_weights_in_xlnet(*args, **kwargs):
requires_backends(load_tf_weights_in_xlnet, ["torch"])
+YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class YolosForObjectDetection(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class YolosModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class YolosPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py
index 37d52fc0943388..00989dc0d12a4c 100644
--- a/src/transformers/utils/dummy_sentencepiece_objects.py
+++ b/src/transformers/utils/dummy_sentencepiece_objects.py
@@ -45,6 +45,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])
+class CpmTokenizer(metaclass=DummyObject):
+ _backends = ["sentencepiece"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["sentencepiece"])
+
+
class DebertaV2Tokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]
@@ -52,6 +59,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])
+class FNetTokenizer(metaclass=DummyObject):
+ _backends = ["sentencepiece"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["sentencepiece"])
+
+
class LayoutXLMTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]
diff --git a/src/transformers/utils/dummy_speech_objects.py b/src/transformers/utils/dummy_speech_objects.py
index 721fe80a7925d0..ae5589292a4cf9 100644
--- a/src/transformers/utils/dummy_speech_objects.py
+++ b/src/transformers/utils/dummy_speech_objects.py
@@ -3,6 +3,13 @@
from ..utils import DummyObject, requires_backends
+class MCTCTFeatureExtractor(metaclass=DummyObject):
+ _backends = ["speech"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["speech"])
+
+
class Speech2TextFeatureExtractor(metaclass=DummyObject):
_backends = ["speech"]
diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py
index cbd3358ac8ffa2..4eb40113e76cf0 100644
--- a/src/transformers/utils/dummy_tf_objects.py
+++ b/src/transformers/utils/dummy_tf_objects.py
@@ -261,6 +261,9 @@ def __init__(self, *args, **kwargs):
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
+TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
+
+
TF_MODEL_FOR_MASKED_LM_MAPPING = None
@@ -335,6 +338,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+class TFAutoModelForNextSentencePrediction(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
class TFAutoModelForPreTraining(metaclass=DummyObject):
_backends = ["tf"]
@@ -742,6 +752,34 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+class TFData2VecVisionForImageClassification(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFData2VecVisionForSemanticSegmentation(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFData2VecVisionModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFData2VecVisionPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1588,6 +1626,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+class TFOPTForCausalLM(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFOPTModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFOPTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
class TFPegasusForConditionalGeneration(metaclass=DummyObject):
_backends = ["tf"]
@@ -1859,6 +1918,37 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFSwinForImageClassification(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSwinForMaskedImageModeling(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSwinModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFSwinPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py
index 64c754164969b2..631df9f25890c8 100644
--- a/src/transformers/utils/dummy_tokenizers_objects.py
+++ b/src/transformers/utils/dummy_tokenizers_objects.py
@@ -52,6 +52,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class BloomTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class CamembertTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
@@ -73,6 +80,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class CpmTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class DebertaTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
@@ -143,6 +157,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class GPTNeoXTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class HerbertTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
@@ -164,6 +185,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class LayoutLMv3TokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class LayoutXLMTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index 6ffeeb52b3e8fc..63e7450be47369 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -59,6 +59,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class FlavaFeatureExtractor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
+class FlavaProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class GLPNFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
@@ -80,14 +94,14 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
-class LayoutLMv2Processor(metaclass=DummyObject):
+class LayoutLMv3FeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
-class LayoutXLMProcessor(metaclass=DummyObject):
+class LevitFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
@@ -141,3 +155,10 @@ class ViTFeatureExtractor(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+
+
+class YolosFeatureExtractor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index 0cfb0b10d5acb1..62f0c98b4ccc9d 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -13,39 +13,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import builtins
+import collections
import functools
import inspect
import math
+import operator
import random
-from types import ModuleType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from packaging import version
from torch import nn
-from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
-from torch.fx.node import Argument
-
-from .. import (
- CONFIG_MAPPING,
- MODEL_FOR_CAUSAL_LM_MAPPING,
- MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
- MODEL_FOR_MASKED_LM_MAPPING,
- MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
- MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
- MODEL_FOR_PRETRAINING_MAPPING,
- MODEL_FOR_QUESTION_ANSWERING_MAPPING,
- MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
- MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
- MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
- MODEL_MAPPING,
- GPT2DoubleHeadsModel,
- PretrainedConfig,
- PreTrainedModel,
- XLNetForQuestionAnswering,
- logging,
-)
+from torch.fx import Graph, GraphModule, Proxy, Tracer
+from torch.fx.proxy import ParameterProxy
+
+from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
+from ..models.auto.modeling_auto import (
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_CTC_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
+ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
+ MODEL_FOR_PRETRAINING_MAPPING_NAMES,
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+ MODEL_MAPPING_NAMES,
+)
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils.versions import importlib_metadata
@@ -53,24 +55,27 @@
logger = logging.get_logger(__name__)
-def _generate_supported_model_classes(
+def _generate_supported_model_class_names(
model_name: Type[PretrainedConfig],
supported_tasks: Optional[Union[str, List[str]]] = None,
-) -> List[Type[PreTrainedModel]]:
+) -> List[str]:
- model_config_class = CONFIG_MAPPING[model_name]
task_mapping = {
- "default": MODEL_MAPPING,
- "pretraining": MODEL_FOR_PRETRAINING_MAPPING,
- "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
- "masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
- "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
- "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
- "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
- "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
- "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
- "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
- "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ "default": MODEL_MAPPING_NAMES,
+ "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
+ "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
+ "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+ "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+ "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+ "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
+ "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+ "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+ "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
+ "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
+ "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+ "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
+ "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
}
if supported_tasks is None:
@@ -78,150 +83,544 @@ def _generate_supported_model_classes(
if isinstance(supported_tasks, str):
supported_tasks = [supported_tasks]
- model_classes = []
+ model_class_names = []
for task in supported_tasks:
- model_class = task_mapping[task].get(model_config_class, None)
- if model_class:
- model_classes.append(model_class)
+ class_name = task_mapping[task].get(model_name, None)
+ if class_name:
+ model_class_names.append(class_name)
- return model_classes
+ return model_class_names
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"albert",
+ "bart",
"bert",
+ "blenderbot",
+ "blenderbot-small",
+ "clip",
+ "deberta",
+ "deberta-v2",
"distilbert",
- "mobilebert",
"electra",
- "megatron-bert",
"gpt2",
- "gptj",
"gpt_neo",
- "t5",
+ "gptj",
+ "hubert",
+ "layoutlm",
+ "lxmert",
+ "m2m_100",
+ "marian",
+ "mbart",
+ "megatron-bert",
+ "mobilebert",
+ "mt5",
+ "opt",
+ "pegasus",
+ "plbart",
"roberta",
- # TODO: add support for them as it should be quite easy to do so (small blocking issues).
- # "layoutlm",
+ "speech_to_text",
+ "speech_to_text_2",
+ "swin",
+ "t5",
+ "trocr",
+ "vit",
+ "xglm",
# "xlnet",
+ # TODO: add support for them as it should be quite easy to do so (small blocking issues).
]
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
if isinstance(item, dict):
- _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
+ _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
else:
- _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))
+ _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))
_SPECIAL_SUPPORTED_MODELS = [
- GPT2DoubleHeadsModel,
+ "CLIPTextModel",
+ "CLIPVisionModel",
+ "GPT2DoubleHeadsModel",
+ "Speech2Text2Decoder",
+ "TrOCRDecoder",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# XLNetForQuestionAnswering,
]
-_SUPPORTED_MODELS = tuple(
- sorted(list(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)), key=lambda c: c.__name__)
-)
+_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
+
+
+def torch_nn_embedding(self, input):
+ return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
+
+
+def torch_nn_functional_embedding(
+ input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
+):
+ return torch.empty(*input.shape, weight.shape[-1], device="meta")
+
+
+def torch_nn_layernorm(self, input):
+ return input
+
+
+def torch_nn_groupnorm(self, input):
+ return input
+
+
+def torch_nn_linear(self, input):
+ return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
+
+
+def torch_relu(x):
+ return x
+
+
+def torch_nn_relu(self, x):
+ return x
+
+
+def torch_nn_functional_relu(x, inplace=False):
+ if not inplace:
+ raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
+ return x
+
+
+def torch_where(condition, x, y):
+ # torch.where returns the broadcasted tensor of condition, x, and y,
+ # so hack it by using addition
+ return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
+
+
+def torch_abs(input, *, out=None):
+ if out is not None:
+ raise ValueError("Don't support in-place abs for MetaTensor analysis")
+ return input
+
+
+def torch_arange(*args, **kwargs):
+ n = len(args)
+ step = 1
+ if n == 1:
+ start = 0
+ end = args[0]
+ elif n == 2:
+ start, end = args
+ else:
+ start, end, step = args
+ if isinstance(start, float):
+ start = int(start)
+ if isinstance(end, float):
+ start = int(end)
+ if isinstance(step, float):
+ step = int(step)
+ step = kwargs.get("step", step)
+ dtype = kwargs.get("dtype")
+ return torch.empty((end - start) // step, dtype=dtype, device="meta")
+
+
+def torch_cat(tensors, dim=None, axis=None, *, out=None):
+ if dim is None and axis is None:
+ dim = 0
+ if dim is None and axis is not None:
+ dim = axis
+ if dim < 0:
+ dim = tensors[0].dim() + dim
+ shapes = [t.shape for t in tensors]
+ shape = list(shapes[0])
+ concatenated_dim = sum(shape[dim] for shape in shapes)
+ final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
+ return torch.empty(final_shape, device="meta")
+
+
+def torch_stack(tensors, dim=None, axis=None, *, out=None):
+ if dim is None and axis is None:
+ dim = 0
+ if dim is None and axis is not None:
+ dim = axis
+ if dim < 0:
+ dim = tensors[0].dim() + 1 + dim
+ shape = list(tensors[0].shape)
+ shape.insert(dim, len(tensors))
+ return torch.empty(shape, device="meta")
+
+
+def torch_add(input, other, *, alpha=1, out=None):
+ if not isinstance(input, torch.Tensor):
+ return torch.empty_like(other, device="meta")
+ if not isinstance(other, torch.Tensor):
+ return torch.empty_like(input, device="meta")
+ max_length = max(input.dim(), other.dim())
+ input_shape = list(input.shape) + [1] * (max_length - input.dim())
+ other_shape = list(other.shape) + [1] * (max_length - other.dim())
+ shape = []
+ for i in range(max_length):
+ shape.append(max(input_shape[i], other_shape[i]))
+ return torch.empty(shape, device="meta")
+
+
+def torch_mul(input, other, *, out=None):
+ return torch_add(input, other, out=out)
+
+
+def torch_tensor_mul(self, other):
+ return torch_mul(self, other)
+
+
+def torch_matmul(input, other, *, out=None):
+ d1 = input.dim()
+ d2 = other.dim()
+ shape = None
+ if d1 == 1 and d2 == 1:
+ shape = None
+ elif d1 == 2 and d2 == 2:
+ shape = (input.size(0), other.size(1))
+ elif d1 == 1 and d2 == 2:
+ shape = (other.size(1),)
+ elif d1 == 2 and d1 == 1:
+ shape = (input.size(0),)
+ else:
+ max_length = max(input.dim(), other.dim())
+ shape1 = list(input.shape)
+ shape2 = list(other.shape)
+ if d1 == 1:
+ shape1 = [1] + shape1
+ if d2 == 1:
+ shape2.append(1)
+ shape1 = [-1] * (max_length - d1) + list(input.shape)
+ shape2 = [-1] * (max_length - d2) + list(other.shape)
+ shape = []
+ for i in range(max_length):
+ shape.append(max(shape1[i], shape2[i]))
+ shape[-2] = shape1[-2]
+ shape[-1] = shape2[-1]
+ if d1 == 1:
+ shape.pop(-2)
+ if d2 == 1:
+ shape.pop(-1)
+ if shape is None:
+ return torch.tensor(0.0, device="meta")
+ return torch.empty(*shape, device="meta")
+
+
+def torch_bmm(input, mat2, *, out=None):
+ if out is not None:
+ raise ValueError("Don't support in-place abs for MetaTensor analysis")
+ batch_size, n, m = input.shape
+ _, _, p = mat2.shape
+ return torch.empty(batch_size, n, p, device="meta")
+
+
+def torch_einsum(equation, *operands):
+ # TODO: infer shape without performing the computation, this might be quite hard.
+ concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
+ return torch.einsum(equation, *concrete_operands).to("meta")
+
+
+def torch_tensor_repeat(self, *sizes):
+ shape = list(self.shape)
+ for i, x in enumerate(sizes):
+ shape[i] *= x
+ return torch.empty(shape, device="meta")
+
+
+def torch_index_select(input, dim, index, *, out=None):
+ shape = list(input.shape)
+ shape[dim] = len(index)
+ return torch.empty(*shape, device="meta")
+
+
+def torch_tensor_index_select(self, dim, index):
+ return torch_index_select(self, dim, index)
+
+
+def torch_roll(input, shifts, dims=None):
+ return input
+
+
+def torch_flip(input, dims):
+ return input
+
+
+def torch_tensor_flip(self, dims):
+ return self
+
+
+def torch_nn_conv1d(self, input):
+ l_in = input.shape[-1]
+ shape = None
+ padding = self.padding
+ if padding == "valid":
+ padding = (0, 0)
+ if padding == "same":
+ shape = list(input.shape)
+ if shape is None:
+ shape = list(input.shape)
+ l_out = math.floor(
+ (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ shape[-1] = l_out
+ shape[-2] = self.out_channels
+ return torch.empty(shape, device="meta")
+
+
+def torch_nn_conv2d(self, input):
+ h_in, w_in = input.shape[-2:]
+ shape = None
+ padding = self.padding
+ if padding == "valid":
+ padding = (0, 0)
+ if padding == "same":
+ shape = list(input.shape)
+ if shape is None:
+ shape = list(input.shape)
+ h_out = math.floor(
+ (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ shape[-2:] = [h_out, w_out]
+ shape[-3] = self.out_channels
+ return torch.empty(shape, device="meta")
+
+
+def torch_squeeze(input, dim=None):
+ shape = list(input.shape)
+ if dim is not None:
+ if dim < 0:
+ dim = input.dim() + dim
+ if shape[dim] == 1:
+ shape.pop(dim)
+ else:
+ new_shape = []
+ for dim_value in shape:
+ if dim_value == 1:
+ continue
+ new_shape.append(dim_value)
+ shape = new_shape
+ return torch.empty(shape, device="meta")
+
+
+def torch_tensor_squeeze(self, dim=None):
+ return torch_squeeze(self, dim)
+
+
+def torch_unsqueeze(input, dim):
+ shape = list(input.shape)
+ if dim < 0:
+ dim = input.dim() + 1 + dim
+ shape.insert(dim, 1)
+ return torch.empty(shape, device="meta")
+
+
+def torch_tensor_unsqueeze(self, dim):
+ return torch_unsqueeze(self, dim)
+
+
+def torch_unique_consecutive(input, **kwargs):
+ output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
+ if isinstance(output, torch.Tensor):
+ return output.to("meta")
+ else:
+ return tuple(map(output, lambda x: x.to("meta")))
+
+
+def torch_nn_functional_one_hot(tensor, num_classes=-1):
+ if num_classes < 0:
+ raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
+ shape = list(tensor.shape) + [num_classes]
+ return torch.empty(shape, device="meta")
+
+
+def torch_nn_mseloss(self, input, target):
+ if self.reduction == "none":
+ shape = target.shape
+ else:
+ shape = (1,)
+ return torch.empty(shape, device="meta")
+
+
+def torch_nn_crossentropyloss(self, input, target):
+ if self.reduction == "none":
+ shape = target.shape
+ else:
+ shape = (1,)
+ return torch.empty(shape, device="meta")
+
+
+def torch_nn_bcewithlogitsloss(self, input, target):
+ if self.reduction == "none":
+ shape = target.shape
+ else:
+ shape = (1,)
+ return torch.empty(shape, device="meta")
+
+
+def operator_getitem(a, b):
+ def to_concrete(t):
+ if isinstance(t, torch.Tensor):
+ concrete = torch.ones_like(t, device="cpu")
+ if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
+ concrete = concrete.to(torch.int64)
+ return concrete
+ return t
+
+ if isinstance(a, torch.Tensor):
+ # TODO: infer shape without performing the computation.
+ if isinstance(b, tuple):
+ b = tuple(map(to_concrete, b))
+ else:
+ b = to_concrete(b)
+ return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
+ return operator.getitem(a, b)
+
+
+_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
+ torch.nn.Embedding: torch_nn_embedding,
+ torch.nn.functional.embedding: torch_nn_functional_embedding,
+ torch.nn.LayerNorm: torch_nn_layernorm,
+ torch.nn.GroupNorm: torch_nn_groupnorm,
+ torch.nn.Linear: torch_nn_linear,
+ torch.relu: torch_relu,
+ torch.nn.functional.relu: torch_nn_functional_relu,
+ torch.nn.ReLU: torch_nn_relu,
+ torch.where: torch_where,
+ torch.abs: torch_abs,
+ torch.arange: torch_arange,
+ torch.cat: torch_cat,
+ torch.stack: torch_stack,
+ torch.add: torch_add,
+ torch.mul: torch_mul,
+ torch.Tensor.mul: torch_tensor_mul,
+ torch.matmul: torch_matmul,
+ torch.bmm: torch_bmm,
+ torch.einsum: torch_einsum,
+ torch.Tensor.repeat: torch_tensor_repeat,
+ torch.roll: torch_roll,
+ torch.flip: torch_flip,
+ torch.Tensor.flip: torch_tensor_flip,
+ torch.index_select: torch_index_select,
+ torch.Tensor.index_select: torch_tensor_index_select,
+ torch.nn.Conv1d: torch_nn_conv1d,
+ torch.nn.Conv2d: torch_nn_conv2d,
+ torch.squeeze: torch_squeeze,
+ torch.Tensor.squeeze: torch_tensor_squeeze,
+ torch.unsqueeze: torch_unsqueeze,
+ torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
+ torch.unique_consecutive: torch_unique_consecutive,
+ torch.nn.functional.one_hot: torch_nn_functional_one_hot,
+ torch.nn.MSELoss: torch_nn_mseloss,
+ torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
+ torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
+ operator.getitem: operator_getitem,
+}
class HFProxy(Proxy):
"""
- Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing
- the dim, size and __bool__ methods. It can be easily extended by either adding new methods or extending the
- existing ones.
+ Proxy that uses metadata to handle data-dependent control-flow.
"""
- def __init__(self, node: Node, tracer: Optional[Tracer] = None):
- super().__init__(node, tracer=tracer)
- if hasattr(self, "tracer") and self.tracer is not None:
- self.device = self.tracer.root.device
- self.dtype = next(self.tracer.root.parameters()).dtype
- self.cache = None
+ def install_metadata(self, metadata):
+ self._metadata = metadata
@property
def shape(self):
- return self.size()
+ return self.tracer.create_proxy("call_method", "size", (self,), {})
- def __setitem__(self, key, value):
- pass
+ @property
+ def dtype(self):
+ if hasattr(self, "_metadata") and self._metadata is not None:
+ return self._metadata.dtype
+ return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
- def __contains__(self, key):
- return False
+ @property
+ def device(self):
+ # Hack so we can track when devices are used. During meta-tensor propagation,
+ # replace these values with a constant 'meta'
+ return MetaDeviceAttribute(self, "device")
- def __eq__(self, other):
- if self.cache is not None:
- return self.cache == other
- elif isinstance(other, HFProxy):
- return True
- else:
- return super().__eq__(other)
+ def __len__(self):
+ if hasattr(self, "_metadata") and self._metadata is not None:
+ return len(self._metadata)
+ return super().__len__()
- def __ne__(self, other):
- return not self == other
+ def __bool__(self):
+ if hasattr(self, "_metadata") and self._metadata is not None:
+ return self._metadata
+ return super().__bool__()
- def __len__(self):
- if self.cache is not None:
- if isinstance(self.cache, int):
- return self.cache
- elif isinstance(self.cache, (torch.Size, list, tuple)):
- return len(self.cache)
- else:
- return super().__len__(self)
- return super().__len__(self)
+ def __getattr__(self, k):
+ if k == "_metadata":
+ return self.__getattribute__(k)
+ # note: not added to the graph yet, if this is a method call
+ # we peephole optimize to the method invocation
+ return HFAttribute(self, k)
- def __torch_function__(self, orig_method, types, args=None, kwargs=None):
- proxy = super().__torch_function__(orig_method, types, args=args, kwargs=kwargs)
- proxy.cache = self.cache
- return proxy
+ def __setitem__(self, indices, values):
+ return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
+ def __contains__(self, key):
+ # To handle cases such as :
+ # `"some_key" in kwargs`
+ if self.node.op == "placeholder":
+ return False
+ return super().__contains__(key)
-def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]:
- """Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer."""
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- return func(*args, **kwargs)
+class HFAttribute(HFProxy):
+ def __init__(self, root, attr: str):
+ self.root = root
+ self.attr = attr
+ self.tracer = root.tracer
+ self._node = None
- return wrapper
+ @property
+ def node(self):
+ # the node for attributes is added lazily, since most will just be method calls
+ # which do not rely on the getitem call
+ if self._node is None:
+ self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
+ return self._node
+ def __call__(self, *args, **kwargs):
+ return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
-def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]:
- @functools.wraps(mapping[func_name])
- def wrapper(*args, **kwargs):
- return mapping[func_name](*args, **kwargs)
- return wrapper
+class MetaDeviceAttribute(HFAttribute):
+ pass
-def _create_recorded_proxy_method(proxy: HFProxy, method_name: str, cache_name: str, return_proxy: bool):
- """
- Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
- during symbolic tracing.
- """
+def _proxies_to_metas(v):
+ """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
+ if isinstance(v, MetaDeviceAttribute):
+ return "meta"
+ if isinstance(v, torch.fx.Proxy):
+ if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")):
+ raise RuntimeError(f"No metadata was found for {v}")
+ return v._metadata
+ return v
- original_method = getattr(torch.Tensor, method_name)
-
- @functools.wraps(original_method)
- def method(*args, **kwargs):
- cache = getattr(args[0].tracer.root, cache_name)
- res = cache.pop(0)
- if return_proxy:
- proxy = args[0].__torch_function__(
- original_method,
- None,
- args=args,
- kwargs=kwargs,
- )
- proxy.cache = res
- return proxy
- return res
- method.__name__ = method_name
- bound_method = method.__get__(proxy, proxy.__class__)
- setattr(proxy, method_name, bound_method)
+def _gen_constructor_wrapper(target):
+ @functools.wraps(target)
+ def wrapper(*args, **kwargs):
+ proxy = None
+
+ def check_has_proxy(v):
+ if isinstance(v, Proxy):
+ nonlocal proxy
+ proxy = v
+
+ torch.fx.node.map_aggregate(args, check_has_proxy)
+ torch.fx.node.map_aggregate(kwargs, check_has_proxy)
+ if proxy is not None:
+ return proxy.tracer.create_proxy("call_function", target, args, kwargs)
+ else:
+ return target(*args, **kwargs)
-def _reset_tensor_methods(original_methods: Dict[str, Callable[..., Any]]):
- """Helper function that resets the monkey patched torch.Tensor methods to their original values."""
- for name, method in original_methods.items():
- setattr(torch.Tensor, name, method)
+ return wrapper, target
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
@@ -239,30 +638,14 @@ class HFTracer(Tracer):
regular PyTorch torch.fx.Proxy.
"""
- _DEFAULT_METHODS_TO_RECORD = {"__bool__": False, "size": True, "dim": False}
- from transformers import modeling_utils
-
- _FUNCTIONS_TO_AUTOWRAP = {
- torch: {"arange", "zeros", "ones", "full_like", "eye"},
- modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"},
- }
-
- def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False):
-
- # Loading the leaf functions register
- self._leaf_functions_register = {}
- for module, names in self._FUNCTIONS_TO_AUTOWRAP.items():
- for name in names:
- self._register_leaf_function(module, name)
+ # Feature flag for proxying accesses to buffer values
+ proxy_buffer_attributes: bool = True
+ allow_insert_stateless_mods: bool = True
+ _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
- # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer.
- # autowrap_functions = autowrap_functions + tuple(
- # patched for (_, _, patched) in self._leaf_functions_register.values()
- # )
+ def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
- super().__init__(
- autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching
- )
+ super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
if not is_torch_fx_available():
torch_version = version.parse(importlib_metadata.version("torch"))
@@ -271,127 +654,117 @@ def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatc
f"{TORCH_FX_REQUIRED_VERSION} is supported."
)
- self.prev_module = None
- self.recorded_methods = None
-
- def _register_leaf_function(self, module: ModuleType, name: str):
- """Registers the function called name in module as a leaf function."""
- orig_func = getattr(module, name)
- patched_func = _function_to_leaf(orig_func)
- patched_func.__module__ = __name__
- self._leaf_functions_register[name] = (module, orig_func, patched_func)
-
- def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore: bool = False):
- """Patches leaf functions specifically for root."""
- for name in self._leaf_functions_register:
- module, orig_func, patched_func = self._leaf_functions_register[name]
- if restore:
- root.__class__.forward.__globals__.pop(name)
- setattr(module, name, orig_func)
- else:
- root.__class__.forward.__globals__[name] = patched_func
- leaf_getter = _function_leaf_getter(name, root.__class__.forward.__globals__)
- leaf_getter.__module__ = __name__
- setattr(module, name, leaf_getter)
-
- def _method_is_called_in_leaf_module(self, module_ids: List[int]) -> bool:
- """
- Finds out if the method (that is being recorded) is called inside a leaf module, this allows to not record
- outputs that will not be encountered by the tracer.
- """
-
- currentframe = inspect.currentframe()
- while currentframe:
- if currentframe is None:
- return False
- module = currentframe.f_locals.get("self", None)
- if id(module) in module_ids and self.is_leaf_module(module, "Not used anyway"):
- return True
- currentframe = currentframe.f_back
- return False
-
- def _wrap_method_for_model_recording(
- self, model: PreTrainedModel, method_name: str, cache_name: str, module_ids: List[int]
- ):
- """Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
- method = getattr(torch.Tensor, method_name)
-
- @functools.wraps(method)
- def wrapped(*args, **kwargs):
- if self._method_is_called_in_leaf_module(module_ids):
- return method(*args, **kwargs)
- if not hasattr(model, cache_name):
- setattr(model, cache_name, [])
- cache = getattr(model, cache_name)
- res = method(*args, **kwargs)
- cache.append(res)
- return res
-
- return wrapped
-
- def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]):
- """
- Helper function that patches torch.Tensor methods (specified by the method_names list) to record model
- inference before symbolic tracing.
- """
- cache_names = {}
- original_methods = {}
- module_ids = set(id(mod) for mod in model.modules())
- for method_name in method_names:
- cache_name = f"cache_{method_name}"
- cache_names[method_name] = cache_name
- if not hasattr(torch.Tensor, method_name):
- logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
- continue
- original_methods[method_name] = getattr(torch.Tensor, method_name)
- setattr(
- torch.Tensor,
- method_name,
- self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids),
- )
-
- if method_name == "size":
- original_methods["shape"] = torch.Tensor.shape
- setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))
-
- return cache_names, original_methods
-
def _generate_dummy_input(
self, model: PreTrainedModel, input_name: str, shape: List[int]
) -> Dict[str, torch.Tensor]:
"""Generates dummy input for model inference recording."""
- model_class = model.__class__
+ # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
+ # from pickle, or from the "__class__" attribute in the general case.
+ model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
device = model.device
inputs_dict = {}
if input_name in ["labels", "start_positions", "end_positions"]:
+
batch_size = shape[0]
- if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+ if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in [
- *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
- XLNetForQuestionAnswering,
+ elif model_class_name in [
+ *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
+ "XLNetForQuestionAnswering",
]:
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in [
- *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
- *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
- *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+ elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
+ if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
+ raise ValueError(
+ "Could not retrieve the problem type for the sequence classification task, please set "
+ 'model.config.problem_type to one of the following values: "regression", '
+ '"single_label_classification", or "multi_label_classification".'
+ )
+
+ if model.config.problem_type == "regression":
+ labels_shape = (batch_size, model.config.num_labels)
+ labels_dtype = torch.float32
+ elif model.config.problem_type == "single_label_classification":
+ labels_shape = (batch_size,)
+ labels_dtype = torch.long
+ elif model.config.problem_type == "multi_label_classification":
+ labels_shape = (batch_size, model.config.num_labels)
+ labels_dtype = torch.float32
+ else:
+ raise ValueError(
+ 'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
+ f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
+ )
+ inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
+
+ elif model_class_name in [
+ *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
+ *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
- elif model_class in [
- *get_values(MODEL_FOR_PRETRAINING_MAPPING),
- *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
- *get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
- *get_values(MODEL_FOR_MASKED_LM_MAPPING),
- *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
- GPT2DoubleHeadsModel,
+ elif model_class_name in [
+ *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
+ *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
+ *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
+ *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
+ *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
+ "GPT2DoubleHeadsModel",
]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else:
- raise NotImplementedError(f"{model_class} not supported yet.")
-
+ raise NotImplementedError(f"{model_class_name} not supported yet.")
+ elif "pixel_values" in input_name:
+ batch_size = shape[0]
+ image_size = getattr(model.config, "image_size", None)
+ if image_size is None:
+ if hasattr(model.config, "vision_config"):
+ image_size = model.config.vision_config.image_size
+ elif hasattr(model.config, "encoder"):
+ image_size = model.config.encoder.image_size
+ else:
+ raise AttributeError('Could not find the "image_size" field in the model config')
+
+ # If no num_channels is in the config, use some arbitrary value.
+ num_channels = getattr(model.config, "num_channels", 3)
+ if not isinstance(image_size, collections.abc.Iterable):
+ image_size = (image_size, image_size)
+ height, width = image_size
+ inputs_dict[input_name] = torch.zeros(
+ batch_size, num_channels, height, width, dtype=torch.float32, device=device
+ )
+ elif "bbox" in input_name:
+ inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
+ elif "input_features" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
+ )
+ elif "visual_feats" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ shape
+ + [
+ model.config.visual_feat_dim,
+ ],
+ dtype=torch.float,
+ device=device,
+ )
+ elif "visual_pos" in input_name:
+ inputs_dict[input_name] = torch.zeros(
+ shape
+ + [
+ model.config.visual_pos_dim,
+ ],
+ dtype=torch.float,
+ device=device,
+ )
+ elif "inputs" in input_name:
+ inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
+ elif "input_values" in input_name:
+ batch_size, _ = shape
+ # Generating big sequence length for audio inputs.
+ seq_length = _generate_random_int(low=10000, high=20000)
+ inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
else:
@@ -400,114 +773,214 @@ def _generate_dummy_input(
return inputs_dict
- def record(self, model: PreTrainedModel, input_names: List[str], method_names: Optional[Iterable[str]] = None):
- """
- Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic tracing.
- """
- if method_names is None:
- method_names = self._DEFAULT_METHODS_TO_RECORD
-
- # Creating a random input shape to generate dummy inputs.
- batch_size = _generate_random_int()
- sequence_length = _generate_random_int()
- shape = [batch_size, sequence_length]
-
- if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
- num_choices = _generate_random_int(low=2, high=5)
- shape.insert(1, num_choices)
-
- inputs = {}
- for input_name in input_names:
- inputs.update(self._generate_dummy_input(model, input_name, shape))
-
- cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names)
- self.original_methods = original_methods
-
- model(**inputs)
+ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
+ rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+
+ if kind == "placeholder" and target in self.meta_args:
+ rv.install_metadata(self.meta_args[target])
+ return rv
+
+ if target in self.orig_fns:
+ # NOTE: tensor constructors in PyTorch define the `device` argument as
+ # *kwargs-only*. That is why this works. If you add methods to
+ # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
+ # this will break and you will likely see issues where we cannot infer
+ # the size of the output.
+ if "device" in kwargs:
+ kwargs["device"] = "meta"
+
+ try:
+ args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
+ kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)
+
+ if kind == "call_function":
+ meta_target = _MANUAL_META_OVERRIDES.get(target, target)
+ meta_out = meta_target(*args_metas, **kwargs_metas)
+ if isinstance(meta_out, torch.Tensor):
+ meta_out = meta_out.to(device="meta")
+ elif kind == "call_method":
+ method = getattr(args_metas[0].__class__, target)
+ meta_target = _MANUAL_META_OVERRIDES.get(method, method)
+ meta_out = meta_target(*args_metas, **kwargs_metas)
+ elif kind == "call_module":
+ if not hasattr(self, "orig_forward"):
+ raise AttributeError(f"{self} does not have an attribute called orig_forward")
+ self._disable_module_getattr = True
+ try:
+ mod = self.root.get_submodule(target)
+ mod_type = type(mod)
+ if mod_type in _MANUAL_META_OVERRIDES:
+ meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
+ else:
+ meta_out = self.orig_forward(*args_metas, **kwargs_metas)
+ finally:
+ self._disable_module_getattr = False
+ elif kind == "get_attr":
+ self._disable_module_getattr = True
+ try:
+ attr_itr = self.root
+ atoms = target.split(".")
+ for atom in atoms:
+ attr_itr = getattr(attr_itr, atom)
+ if isinstance(attr_itr, torch.Tensor):
+ meta_out = attr_itr.to(device="meta")
+ else:
+ meta_out = attr_itr
+ finally:
+ self._disable_module_getattr = False
+ else:
+ return rv
- _reset_tensor_methods(original_methods)
+ if not isinstance(rv, Proxy):
+ raise ValueError("Don't support composite output yet")
+ rv.install_metadata(meta_out)
+ except Exception as e:
+ warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
- self.recorded_methods = {
- method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name)
- }
+ return rv
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
- if isinstance(attr_val, torch.nn.Parameter):
- for n, p in self.root.named_parameters():
- if attr_val is p:
- if n not in parameter_proxy_cache:
- parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
- return parameter_proxy_cache[n]
- # TODO: condition this on wether dynamic axes were requested.
- if isinstance(attr_val, torch.Tensor):
- for n, p in self.root.named_buffers():
- if attr_val is p:
- if n not in parameter_proxy_cache:
- parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
- return parameter_proxy_cache[n]
- return attr_val
-
- def proxy(self, node: Node):
- p = HFProxy(node, self)
- if self.recorded_methods:
- for method_name, cache_name in self.recorded_methods.items():
- return_proxy = self._DEFAULT_METHODS_TO_RECORD[method_name]
- _create_recorded_proxy_method(p, method_name, cache_name, return_proxy)
- return p
-
- def trace(
- self,
- root: PreTrainedModel,
- concrete_args: Optional[Dict[str, Any]] = None,
- method_names: Optional[Iterable[str]] = None,
- ) -> Graph:
+ if getattr(self, "_disable_module_getattr", False):
+ return attr_val
+ else:
+ # return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
+ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
+ for n, p in collection_to_search:
+ if attr_val is p:
+ if n not in parameter_proxy_cache:
+ kwargs = {}
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ParameterProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
+ parameter_proxy_cache[n] = val_proxy
+ return parameter_proxy_cache[n]
+ return None
+
+ if isinstance(attr_val, torch.nn.Parameter):
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
+ if maybe_parameter_proxy is not None:
+ return maybe_parameter_proxy
+
+ if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
+ maybe_buffer_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_buffers(), parameter_proxy_cache
+ )
+ if maybe_buffer_proxy is not None:
+ return maybe_buffer_proxy
+
+ return attr_val
+
+ def call_module(self, m, forward, args, kwargs):
+ self.orig_forward = forward
+ return super().call_module(m, forward, args, kwargs)
+
+ def proxy(self, node):
+ return HFProxy(node, self)
+
+ def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
if concrete_args is None:
concrete_args = {}
sig = inspect.signature(root.forward)
input_names = sig.parameters.keys() - concrete_args.keys()
- self.record(root, input_names, method_names=method_names)
+ # Creating a random input shape to generate dummy inputs.
+ batch_size = _generate_random_int()
+ sequence_length = _generate_random_int()
+ shape = [batch_size, sequence_length]
- # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer.
- autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()]
- self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
+ if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
+ num_choices = _generate_random_int(low=2, high=5)
+ shape.insert(1, num_choices)
- self._patch_leaf_functions_for_root(root)
+ inputs = {}
+ for input_name in input_names:
+ inputs.update(self._generate_dummy_input(root, input_name, shape))
- self.graph = super().trace(root, concrete_args=concrete_args)
+ concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
+ self.meta_args = concrete_metas
+ self.patched_torch_methods = {
+ target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
+ }
+ self.orig_fns = set()
- self._patch_leaf_functions_for_root(root, restore=True)
+ for name, (wrapper, orig) in self.patched_torch_methods.items():
+ setattr(torch, name, wrapper)
+ self.orig_fns.add(orig)
- _reset_tensor_methods(self.original_methods)
+ try:
+ self.graph = super().trace(root, concrete_args=concrete_args)
+ finally:
+ for name, (_, orig) in self.patched_torch_methods.items():
+ setattr(torch, name, orig)
- # TODO: keep this until necessary.
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
- # A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in input_names:
node.args = ()
+ # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
+ # It cannot infer on the attributes and methods the input should have, and fails.
+ node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
+ if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
+ # Newer versions of torch.fx emit an assert statement
+ # for concrete arguments; delete those before we delete
+ # the concrete arg.
+ to_delete = []
+ for user in node.users:
+ if user.target == torch.fx._symbolic_trace._assert_is_none:
+ to_delete.append(user)
+ for user in to_delete:
+ self.graph.erase_node(user)
+
self.graph.erase_node(node)
+ # TODO: solves GraphModule creation.
+ # Without this, return type annotation "Tuple" is causing code execution failure.
+ if node.op == "output":
+ node.type = None
+
return self.graph
+ def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
+ """
+ Whether the module was instantiated with Proxies. If that is the case, such module cannot be a leaf module
+ because its attributes are input-dependent.
+ """
+ return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())
+
def _insert_module_as_submodule(self, mod: nn.Module) -> str:
"""
Helper method which tries to insert a module that was not declared as submodule.
"""
+ # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
+ # It is not possible to insert such modules, those should be traced through.
+ if self._stateless_mod_instanciation_depends_on_proxies(mod):
+ return ""
idx = 0
mod_name = mod.__class__.__name__.lower()
path = f"{mod_name}_{idx}"
+ already_inserted = False
while hasattr(self.root, path):
+ if getattr(self.root, path) is mod:
+ already_inserted = True
+ break
path = f"{mod_name}_{idx}"
idx += 1
- self.root.add_module(path, mod)
+ # No need to add multiple instances of the same module.
+ if not already_inserted:
+ self.root.add_module(path, mod)
return path
def path_of_module(self, mod: nn.Module) -> str:
@@ -519,42 +992,46 @@ def path_of_module(self, mod: nn.Module) -> str:
Args:
mod (str): The `Module` to retrieve the qualified name for.
"""
- # Prefer the O(1) algorithm
- if hasattr(self, "submodule_paths") and self.submodule_paths:
- path = self.submodule_paths.get(mod)
- if path is None:
+ try:
+ return super().path_of_module(mod)
+ except NameError as e:
+ if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
path = self._insert_module_as_submodule(mod)
- if path is None:
- raise NameError(f"Module named {mod._get_name()} is not installed as a submodule")
- self.prev_module = path
- return path
+ return path
+ raise e
- # O(N^2) fallback in the case that we didn't store the submodule
- # paths.
- else:
- for n, p in self.root.named_modules():
- if mod is p:
- self.prev_module = n
- return n
- path = self._insert_module_as_submodule(mod)
- if path is None:
- raise NameError(f"Module {mod._get_name()} is not installed as a submodule")
- self.prev_module = path
- return path
-
- def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
- is_loss_module = m.__module__.startswith("torch.nn.modules.loss")
- return (not is_loss_module) and super().is_leaf_module(m, module_qualified_name)
-
- def create_arg(self, a: Any) -> Argument:
- if isinstance(a, range):
- return super().create_arg(list(a))
- return super().create_arg(a)
+ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+ return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module(
+ m, module_qualified_name
+ )
+
+
+def get_concrete_args(model: nn.Module, input_names: List[str]):
+ sig = inspect.signature(model.forward)
+
+ if not (set(input_names) <= set(sig.parameters.keys())):
+ formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
+ formatted_allowed_input_names = ", ".join(sig.parameters.keys())
+ raise ValueError(
+ f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
+ f" {formatted_allowed_input_names}"
+ )
+
+ return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
+
+
+def check_if_model_is_supported(model: PreTrainedModel):
+ if model.__class__.__name__ not in _SUPPORTED_MODELS:
+ supported_model_names = ", ".join(_SUPPORTED_MODELS)
+ raise NotImplementedError(
+ f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
+ )
def symbolic_trace(
model: PreTrainedModel,
input_names: Optional[List[str]] = None,
+ disable_check: bool = False,
) -> GraphModule:
"""
@@ -565,6 +1042,8 @@ def symbolic_trace(
The model to trace.
input_names (`List[str]`, *optional*):
The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
+ disable_check (`bool`, *optional*, defaults to `False`):
+ If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
Returns:
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
@@ -580,18 +1059,21 @@ def symbolic_trace(
if input_names is None:
input_names = model.dummy_inputs.keys()
- sig = inspect.signature(model.forward)
- concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
+ input_names = list(input_names)
+ concrete_args = get_concrete_args(model, input_names)
- if not isinstance(model, _SUPPORTED_MODELS):
- supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
- raise NotImplementedError(
- f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
- )
+ if not disable_check:
+ check_if_model_is_supported(model)
# Tracing.
tracer = HFTracer()
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)
+ traced.config = model.config
+ # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
+ # _generate_dummy_input, where the model class is needed.
+ traced.class_for_deserialization = model.__class__
+ traced.device = model.device
+
return traced
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index bea5b3dd47753a..136762a3785863 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -17,6 +17,7 @@
import inspect
from collections import OrderedDict, UserDict
+from collections.abc import MutableMapping
from contextlib import ExitStack
from dataclasses import fields
from enum import Enum
@@ -310,3 +311,17 @@ def find_labels(model_class):
return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
else:
return [p for p in signature.parameters if "label" in p]
+
+
+def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
+ """Flatten a nested dict into a single level dict."""
+
+ def _flatten_dict(d, parent_key="", delimiter="."):
+ for k, v in d.items():
+ key = str(parent_key) + delimiter + str(k) if parent_key else k
+ if v and isinstance(v, MutableMapping):
+ yield from flatten_dict(v, key, delimiter=delimiter).items()
+ else:
+ yield key, v
+
+ return dict(_flatten_dict(d, parent_key, delimiter))
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index 7386fe34f521c7..a4717cf7ea629f 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -38,6 +38,7 @@
from filelock import FileLock
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
from requests.exceptions import HTTPError
+from requests.models import Response
from transformers.utils.logging import tqdm
from . import __version__, logging
@@ -77,11 +78,11 @@ def is_offline_mode():
and "TRANSFORMERS_CACHE" not in os.environ
):
logger.warning(
- "In Transformers v4.0.0, the default path to cache downloaded models changed from "
- "'~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have overridden "
- "and '~/.cache/torch/transformers' is a directory that exists, we're moving it to "
- "'~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should "
- "only see this message once."
+ "In Transformers v4.0.0, the default path to cache downloaded models changed from"
+ " '~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have"
+ " overridden and '~/.cache/torch/transformers' is a directory that exists, we're moving it to"
+ " '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should"
+ " only see this message once."
)
shutil.move(old_default_cache_path, default_cache_path)
@@ -109,6 +110,7 @@ def is_offline_mode():
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None)
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_RESOLVE_ENDPOINT)
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
+HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
def is_remote_url(url_or_filename):
@@ -397,20 +399,27 @@ class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
-def _raise_for_status(request):
+def _raise_for_status(response: Response):
"""
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
"""
- if "X-Error-Code" in request.headers:
- error_code = request.headers["X-Error-Code"]
+ if "X-Error-Code" in response.headers:
+ error_code = response.headers["X-Error-Code"]
if error_code == "RepoNotFound":
- raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {request.url}")
+ raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {response.url}")
elif error_code == "EntryNotFound":
- raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {request.url}")
+ raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {response.url}")
elif error_code == "RevisionNotFound":
- raise RevisionNotFoundError((f"404 Client Error: Revision Not Found for url: {request.url}"))
+ raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {response.url}")
- request.raise_for_status()
+ if response.status_code == 401:
+ # The repo was not found and the user is not Authenticated
+ raise RepositoryNotFoundError(
+ f"401 Client Error: Repository not found for url: {response.url}. "
+ "If the repo is private, make sure you are authenticated."
+ )
+
+ response.raise_for_status()
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
@@ -1028,3 +1037,41 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
+
+
+def send_example_telemetry(example_name, *example_args, framework="pytorch"):
+ """
+ Sends telemetry that helps tracking the examples use.
+
+ Args:
+ example_name (`str`): The name of the example.
+ *example_args (dataclasses or `argparse.ArgumentParser`): The arguments to the script. This function will only
+ try to extract the model and dataset name from those. Nothing else is tracked.
+ framework (`str`, *optional*, defaults to `"pytorch"`): The framework for the example.
+ """
+ if is_offline_mode():
+ return
+
+ data = {"example": example_name, "framework": framework}
+ for args in example_args:
+ args_as_dict = {k: v for k, v in args.__dict__.items() if not k.startswith("_") and v is not None}
+ if "model_name_or_path" in args_as_dict:
+ model_name = args_as_dict["model_name_or_path"]
+ # Filter out local paths
+ if not os.path.isdir(model_name):
+ data["model_name"] = args_as_dict["model_name_or_path"]
+ if "dataset_name" in args_as_dict:
+ data["dataset_name"] = args_as_dict["dataset_name"]
+ elif "task_name" in args_as_dict:
+ # Extract script name from the example_name
+ script_name = example_name.replace("tf_", "").replace("flax_", "").replace("run_", "")
+ script_name = script_name.replace("_no_trainer", "")
+ data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"
+
+ headers = {"user-agent": http_user_agent(data)}
+ try:
+ r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers)
+ r.raise_for_status()
+ except Exception:
+ # We don't want to error in case of connection errors of any kind.
+ pass
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 505ba94e0b193c..53f7515bca5fa8 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -282,25 +282,36 @@ def is_torch_bf16_available():
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
# with additional check for torch version
# to succeed:
- # 1. the hardware needs to support bf16 (arch >= Ampere)
- # 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
- # 3. CUDA >= 11
+ # 1. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
+ # 2. the hardware needs to support bf16 (GPU arch >= Ampere, or CPU)
+ # 3. if using gpu, CUDA >= 11
# 4. torch.autocast exists
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
# really only correct for the 0th gpu (or currently set default device if different from 0)
-
- if not torch.cuda.is_available() or torch.version.cuda is None:
- return False
- if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
- return False
- if int(torch.version.cuda.split(".")[0]) < 11:
- return False
+ is_torch_gpu_bf16_available = True
+ is_torch_cpu_bf16_available = True
if version.parse(torch.__version__) < version.parse("1.10"):
- return False
- if not hasattr(torch, "autocast"):
- return False
+ is_torch_gpu_bf16_available = False
+ is_torch_cpu_bf16_available = False
+
+ if torch.cuda.is_available() and torch.version.cuda is not None:
+ if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
+ is_torch_gpu_bf16_available = False
+ if int(torch.version.cuda.split(".")[0]) < 11:
+ is_torch_gpu_bf16_available = False
+ if not hasattr(torch.cuda.amp, "autocast"):
+ is_torch_gpu_bf16_available = False
+ else:
+ is_torch_gpu_bf16_available = False
- return True
+ # checking CPU
+ try:
+ # multiple levels of AttributeError depending on the pytorch version so do them all in one check
+ _ = torch.cpu.amp.autocast
+ except AttributeError:
+ is_torch_cpu_bf16_available = False
+
+ return is_torch_cpu_bf16_available or is_torch_gpu_bf16_available
def is_torch_tf32_available():
@@ -325,7 +336,7 @@ def is_torch_tf32_available():
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
if _torch_available:
torch_version = version.parse(importlib_metadata.version("torch"))
- _torch_fx_available = (torch_version.major, torch_version.minor) == (
+ _torch_fx_available = (torch_version.major, torch_version.minor) >= (
TORCH_FX_REQUIRED_VERSION.major,
TORCH_FX_REQUIRED_VERSION.minor,
)
@@ -376,6 +387,10 @@ def is_torch_tpu_available():
return importlib.util.find_spec("torch_xla.core.xla_model") is not None
+def is_torchdynamo_available():
+ return importlib.util.find_spec("torchdynamo") is not None
+
+
def is_datasets_available():
return _datasets_available
@@ -400,6 +415,10 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None
+def is_ipex_available():
+ return importlib.util.find_spec("intel_extension_for_pytorch") is not None
+
+
def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None
@@ -428,6 +447,10 @@ def is_protobuf_available():
return importlib.util.find_spec("google.protobuf") is not None
+def is_accelerate_available():
+ return importlib.util.find_spec("accelerate") is not None
+
+
def is_tokenizers_available():
return importlib.util.find_spec("tokenizers") is not None
@@ -452,6 +475,8 @@ def is_in_notebook():
raise ImportError("console")
if "VSCODE_PID" in os.environ:
raise ImportError("vscode")
+ if "DATABRICKS_RUNTIME_VERSION" in os.environ:
+ raise ImportError("databricks")
return importlib.util.find_spec("IPython") is not None
except (AttributeError, ImportError, KeyError):
@@ -725,6 +750,12 @@ def wrapper(*args, **kwargs):
`pip install pyctcdecode`
"""
+# docstyle-ignore
+ACCELERATE_IMPORT_ERROR = """
+{0} requires the accelerate library but it was not found in your environment. You can install it with pip:
+`pip install accelerate`
+"""
+
BACKENDS_MAPPING = OrderedDict(
[
@@ -750,6 +781,7 @@ def wrapper(*args, **kwargs):
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+ ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
]
)
@@ -861,8 +893,13 @@ def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
- f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}"
+ f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
+ f" traceback):\n{e}"
) from e
def __reduce__(self):
return (self.__class__, (self._name, self.__file__, self._import_structure))
+
+
+class OptionalDependencyNotAvailable(BaseException):
+ """Internally used error class for signalling an optional dependency was not found."""
diff --git a/src/transformers/utils/model_parallel_utils.py b/src/transformers/utils/model_parallel_utils.py
index abddd6c60faccf..bcbe808013596f 100644
--- a/src/transformers/utils/model_parallel_utils.py
+++ b/src/transformers/utils/model_parallel_utils.py
@@ -32,13 +32,15 @@ def assert_device_map(device_map, num_blocks):
if len(duplicate_blocks) != 0:
raise ValueError(
- "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These "
- "attention blocks were specified more than once: " + str(duplicate_blocks)
+ "Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device."
+ " These attention blocks were specified more than once: "
+ + str(duplicate_blocks)
)
if len(missing_blocks) != 0:
raise ValueError(
"There are attention blocks for this model that are not specified in the device_map. Add these attention "
- "blocks to a device on the device_map: " + str(missing_blocks)
+ "blocks to a device on the device_map: "
+ + str(missing_blocks)
)
if len(extra_blocks) != 0:
raise ValueError(
diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py
index 0ffbdc8deecff8..f671ad737c3fad 100644
--- a/src/transformers/utils/notebook.py
+++ b/src/transformers/utils/notebook.py
@@ -174,7 +174,10 @@ def update_bar(self, value, comment=None):
elif self.predicted_remaining is None:
self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)}"
else:
- self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)} < {format_time(self.predicted_remaining)}"
+ self.label = (
+ f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)} <"
+ f" {format_time(self.predicted_remaining)}"
+ )
self.label += f", {1/self.average_time_per_item:.2f} it/s"
self.label += "]" if self.comment is None or len(self.comment) == 0 else f", {self.comment}]"
self.display()
diff --git a/src/transformers/utils/sentencepiece_model_pb2.py b/src/transformers/utils/sentencepiece_model_pb2.py
index 5d52b365caab7f..41411cee8cd65b 100644
--- a/src/transformers/utils/sentencepiece_model_pb2.py
+++ b/src/transformers/utils/sentencepiece_model_pb2.py
@@ -32,7 +32,53 @@
syntax="proto2",
serialized_options=b"H\003",
create_key=_descriptor._internal_create_key,
- serialized_pb=b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\xa1\n\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03',
+ serialized_pb=(
+ b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\xa1\n\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01'
+ b" \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02"
+ b" \x01(\t\x12\x41\n\nmodel_type\x18\x03"
+ b" \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04"
+ b" \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12"
+ b' \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n'
+ b" \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b"
+ b" \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12"
+ b' \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r'
+ b" \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e"
+ b" \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f"
+ b" \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12"
+ b" \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10"
+ b" \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11"
+ b" \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14"
+ b" \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15"
+ b" \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17"
+ b" \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16"
+ b" \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18"
+ b" \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19"
+ b" \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e"
+ b" \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$"
+ b" \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18"
+ b' \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18"'
+ b" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18)"
+ b" \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+"
+ b" \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18."
+ b" \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30"
+ b" \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87"
+ b" \x12+\n\x1ctrain_extremely_large_corpus\x18\x31"
+ b' \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01'
+ b" \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03"
+ b" \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12"
+ b" \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06"
+ b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01'
+ b' \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01'
+ b" \x01(\t\x12\x10\n\x08\x65xpected\x18\x02"
+ b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01'
+ b" \x03(\x0b\x32'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02"
+ b" \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03"
+ b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04"
+ b" \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05"
+ b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01"
+ b" \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03"
+ b' \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
+ ),
)
diff --git a/src/transformers/utils/versions.py b/src/transformers/utils/versions.py
index 26a160f1fd6eaa..14db9b55e59704 100644
--- a/src/transformers/utils/versions.py
+++ b/src/transformers/utils/versions.py
@@ -77,7 +77,8 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
if not match:
raise ValueError(
- f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
+ "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but"
+ f" got {requirement}"
)
pkg, want_full = match[0]
want_range = want_full.split(",") # there could be multiple requirements
@@ -86,7 +87,8 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
if not match:
raise ValueError(
- f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
+ "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23,"
+ f" but got {requirement}"
)
op, want_ver = match[0]
wanted[op] = want_ver
diff --git a/templates/adding_a_missing_tokenization_test/cookiecutter-template-{{cookiecutter.modelname}}/test_tokenization_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_missing_tokenization_test/cookiecutter-template-{{cookiecutter.modelname}}/test_tokenization_{{cookiecutter.lowercase_modelname}}.py
index 631886f6b2ebbc..36e35c04ed336e 100644
--- a/templates/adding_a_missing_tokenization_test/cookiecutter-template-{{cookiecutter.modelname}}/test_tokenization_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_missing_tokenization_test/cookiecutter-template-{{cookiecutter.modelname}}/test_tokenization_{{cookiecutter.lowercase_modelname}}.py
@@ -26,25 +26,25 @@
{% endif -%}
{% if cookiecutter.has_fast_class == "True" and cookiecutter.slow_tokenizer_use_sentencepiece == "True" -%}
from transformers.testing_utils import require_sentencepiece, require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_sentencepiece
@require_tokenizers
{% elif cookiecutter.slow_tokenizer_use_sentencepiece == "True" -%}
from transformers.testing_utils import require_sentencepiece
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_sentencepiece
{% elif cookiecutter.has_fast_class == "True" -%}
from transformers.testing_utils import require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
{% else -%}
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
{% endif -%}
diff --git a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
index 0d9a0c6d32fec7..f07029ec242caa 100755
--- a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
+++ b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
@@ -28,6 +28,7 @@
from typing import Optional, List
import datasets
+import torch
from datasets import load_dataset
import transformers
@@ -45,6 +46,7 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__)
@@ -206,6 +208,10 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_{{cookiecutter.example_shortcut}}", model_args, data_args)
+
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
@@ -518,6 +524,7 @@ def _mp_fn(index):
get_scheduler,
set_seed,
)
+from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__)
@@ -661,6 +668,10 @@ def parse_args():
def main():
args = parse_args()
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_{{cookiecutter.example_shortcut}", args)
+
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
# Make one log on every process with the configuration for debugging.
@@ -871,7 +882,8 @@ def tokenize_function(examples):
model.eval()
for step, batch in enumerate(eval_dataloader):
- outputs = model(**batch)
+ with torch.no_grad():
+ outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
metric.add_batch(
predictions=accelerator.gather(predictions),
diff --git a/templates/adding_a_new_model/README.md b/templates/adding_a_new_model/README.md
index 496c4f004be576..4bb6663937ce77 100644
--- a/templates/adding_a_new_model/README.md
+++ b/templates/adding_a_new_model/README.md
@@ -222,7 +222,7 @@ You will also see a doc file and tests for your new models. First you should run
```
make style
-maxke fix-copies
+make fix-copies
```
and then you can start tweaking your model. You should:
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py
index afcfeb87eb7789..0d05ee406addff 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py
@@ -18,15 +18,23 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...utils import _LazyModule, is_tokenizers_available
+from ...utils import _LazyModule, OptionalDependencyNotAvailable, is_tokenizers_available
+
+
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_tf_available
+
+
{% endif %}
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_torch_available
+
+
{% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_flax_available
+
+
{% endif %}
_import_structure = {
@@ -34,12 +42,22 @@
"tokenization_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.camelcase_modelname}}Tokenizer"],
}
-if is_tokenizers_available():
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["tokenization_{{cookiecutter.lowercase_modelname}}_fast"] = ["{{cookiecutter.camelcase_modelname}}TokenizerFast"]
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForMaskedLM",
@@ -54,7 +72,12 @@
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
]
{% else %}
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
@@ -70,7 +93,12 @@
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
@@ -84,7 +112,12 @@
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% else %}
-if is_tf_available():
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"TF{{cookiecutter.camelcase_modelname}}Model",
@@ -96,7 +129,12 @@
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
"Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"Flax{{cookiecutter.camelcase_modelname}}ForCausalLM",
@@ -109,7 +147,12 @@
"Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% else %}
-if is_flax_available():
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
"Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
@@ -125,12 +168,22 @@
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
- if is_tokenizers_available():
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
@@ -145,7 +198,12 @@
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
)
{% else %}
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
@@ -159,7 +217,12 @@
{% endif %}
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
@@ -173,7 +236,12 @@
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% else %}
- if is_tf_available():
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
TF{{cookiecutter.camelcase_modelname}}Model,
@@ -183,7 +251,12 @@
{% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM,
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
@@ -196,7 +269,12 @@
Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% else %}
- if is_flax_available():
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
index b485a0d27919e2..451dc03f62ed13 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
@@ -24,15 +24,17 @@
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
+from flax.linen import combine_masks, make_causal_mask
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
- FlaxBaseModelOutput,
- FlaxBaseModelOutputWithPooling,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxBaseModelOutputWithPoolingAndCrossAttentions,
FlaxCausalLMOutput,
+ FlaxCausalLMOutputWithCrossAttentions,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
@@ -170,9 +172,11 @@ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, dete
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
+ causal: bool = False
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
+ self.head_dim = self.config.hidden_size // self.config.num_attention_heads
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
@@ -195,30 +199,113 @@ def setup(self):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
+
+ @nn.compact
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slighly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
+ key_value_states: Optional[jnp.array] = None,
+ init_cache: bool = False,
deterministic=True,
- output_attentions: bool = False
+ output_attentions: bool = False,
):
- head_dim = self.config.hidden_size // self.config.num_attention_heads
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
- query_states = self.query(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
- value_states = self.value(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
- key_states = self.key(hidden_states).reshape(
- hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
- )
+ # get query proj
+ query_states = self.query(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.key(key_value_states)
+ value_states = self.value(key_value_states)
+ else:
+ # self_attention
+ key_states = self.key(hidden_states)
+ value_states = self.value(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
- attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
@@ -278,6 +365,7 @@ def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
config: {{cookiecutter.camelcase_modelname}}Config
+ causal: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -289,6 +377,8 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
+ key_value_states=None,
+ init_cache=False,
deterministic=True,
output_attentions: bool = False,
):
@@ -299,6 +389,8 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
+ key_value_states=key_value_states,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
@@ -362,24 +454,43 @@ def setup(self):
self.attention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, dtype=self.dtype)
self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype)
self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype)
+ if self.config.add_cross_attention:
+ self.crossattention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, causal=False, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
layer_head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
):
+ # Self Attention
attention_outputs = self.attention(
hidden_states,
attention_mask,
layer_head_mask=layer_head_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
+ # Cross-Attention Block
+ if encoder_hidden_states is not None:
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=layer_head_mask,
+ key_value_states=encoder_hidden_states,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
@@ -387,6 +498,8 @@ def __call__(
if output_attentions:
outputs += (attention_outputs[1],)
+ if encoder_hidden_states is not None:
+ outputs += (cross_attention_outputs[1],)
return outputs
@@ -405,6 +518,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -412,6 +528,7 @@ def __call__(
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# Check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
@@ -429,6 +546,9 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
)
@@ -438,6 +558,9 @@ def __call__(
if output_attentions:
all_attentions += (layer_outputs[1],)
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -446,8 +569,11 @@ def __call__(
if not return_dict:
return tuple(v for v in outputs if v is not None)
- return FlaxBaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
)
@@ -464,6 +590,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -473,6 +602,9 @@ def __call__(
hidden_states,
attention_mask,
head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -598,6 +730,7 @@ def __init__(
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -609,9 +742,26 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
- random_params = self.module.init(
- rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
- )["params"]
+ if self.config.add_cross_attention:
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
+ encoder_attention_mask = attention_mask
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ return_dict=False,
+ )
+ else:
+ module_init_outputs = self.module.init(
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
+ )
+
+ random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
@@ -623,7 +773,29 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
else:
return random_params
+
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_cache with Bert->{{cookiecutter.camelcase_modelname}}
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length))
+ attention_mask = jnp.ones_like(input_ids)
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->{{cookiecutter.camelcase_modelname}}
def __call__(
self,
input_ids,
@@ -631,12 +803,15 @@ def __call__(
token_type_ids=None,
position_ids=None,
head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ past_key_values: dict = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -662,19 +837,60 @@ def __call__(
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
- return self.module.apply(
- {"params": params or self.params},
- jnp.array(input_ids, dtype="i4"),
- jnp.array(attention_mask, dtype="i4"),
- jnp.array(token_type_ids, dtype="i4"),
- jnp.array(position_ids, dtype="i4"),
- jnp.array(head_mask, dtype="i4"),
- not train,
- output_attentions,
- output_hidden_states,
- return_dict,
- rngs=rngs,
- )
+ inputs = {"params": params or self.params}
+
+ if self.config.add_cross_attention:
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
+ # changed by FlaxBertAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ head_mask=jnp.array(head_mask, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ deterministic=not train,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ else:
+ outputs = self.module.apply(
+ inputs,
+ jnp.array(input_ids, dtype="i4"),
+ jnp.array(attention_mask, dtype="i4"),
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ head_mask=jnp.array(head_mask, dtype="i4"),
+ deterministic=not train,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ rngs=rngs,
+ )
+
+ return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->{{cookiecutter.camelcase_modelname}}
class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
@@ -691,14 +907,25 @@ def __call__(
self,
input_ids,
attention_mask,
- token_type_ids,
- position_ids,
- head_mask,
+ token_type_ids: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ head_mask: Optional[jnp.ndarray] = None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
+ # make sure `token_type_ids` is correctly initialized when not passed
+ if token_type_ids is None:
+ token_type_ids = jnp.zeros_like(input_ids)
+
+ # make sure `position_ids` is correctly initialized when not passed
+ if position_ids is None:
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
@@ -707,6 +934,9 @@ def __call__(
attention_mask,
head_mask=head_mask,
deterministic=deterministic,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -720,11 +950,12 @@ def __call__(
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
- return FlaxBaseModelOutputWithPooling(
+ return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
)
add_start_docstrings(
@@ -1137,6 +1368,112 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(Flax{{cookiec
FlaxQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
+
+
+class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
+ config: {{cookiecutter.camelcase_modelname}}Config
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+ self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ token_type_ids: Optional[jnp.ndarray] = None,
+ head_mask: Optional[jnp.ndarray] = None,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # Model
+ outputs = self.{{cookiecutter.lowercase_modelname}}(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ position_ids,
+ head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.{{cookiecutter.lowercase_modelname}}.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+ else:
+ shared_embedding = None
+
+ # Compute the prediction scores
+ logits = self.cls(hidden_states, shared_embedding=shared_embedding)
+
+ if not return_dict:
+ return (logits,) + outputs[1:]
+
+ return FlaxCausalLMOutputWithCrossAttentions(
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ {{cookiecutter.camelcase_modelname}} Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
+ autoregressive tasks.
+ """,
+ {{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
+)
+
+class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
+ module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
+ _TOKENIZER_FOR_DOC,
+ _CHECKPOINT_FOR_DOC,
+ FlaxCausalLMOutputWithCrossAttentions,
+ _CONFIG_FOR_DOC,
+)
{# encoder_decoder #}
{% else %}
import math
@@ -1353,7 +1690,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids
-
+
class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
@@ -1659,7 +1996,7 @@ def setup(self) -> None:
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
- self.config.encoder_ffn_dim,
+ self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
@@ -2660,10 +2997,10 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
```python
>>> import jax
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
-
+
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
-
+
>>> TXT = "My friends are but they eat too many carbs."
>>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids']
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index bde5eaa2e3b95f..7d09a77b70ec48 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -876,7 +876,7 @@ def forward(
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -2100,7 +2100,7 @@ def _set_gradient_checkpointing(self, module, value=False):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
also be used by default.
- If you want to change padding behavior, you should read [`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs`] and
+ If you want to change padding behavior, you should read [`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_attention_mask`] and
modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_flax_{{cookiecutter.lowercase_modelname}}.py
index 69b0a7fae20129..37b22a75c3e970 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_flax_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_flax_{{cookiecutter.lowercase_modelname}}.py
@@ -20,8 +20,8 @@
from transformers import is_flax_available, {{cookiecutter.camelcase_modelname}}Config
from transformers.testing_utils import require_flax, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import numpy as np
@@ -345,8 +345,8 @@ def test_inference_masked_lm(self):
)
from transformers.testing_utils import require_sentencepiece, require_flax, require_tokenizers, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py
index 0f4d7824c16420..48cd1239eaed1d 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py
@@ -20,8 +20,8 @@
from transformers import is_tf_available, {{cookiecutter.camelcase_modelname}}Config
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_tf_available():
@@ -711,8 +711,8 @@ def test_inference_masked_lm(self):
)
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py
index e15adc91e30c15..7becb51551832b 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -18,13 +18,13 @@
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
import unittest
-from ..test_modeling_common import floats_tensor
+from ...test_modeling_common import floats_tensor
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers import {{cookiecutter.camelcase_modelname}}Config
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
@@ -489,9 +489,9 @@ def test_inference_masked_lm(self):
from transformers.utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py
index c95b82115dc36e..273adca0ef230e 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py
@@ -115,7 +115,7 @@
{% endif -%}
# End.
-# Below: " # Fast tokenizers"
+# Below: " # Fast tokenizers structure"
# Replace with:
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast")
# End.
@@ -126,7 +126,7 @@
# End.
# To replace in: "src/transformers/__init__.py"
-# Below: " if is_torch_available():" if generating PyTorch
+# Below: " # PyTorch model imports" if generating PyTorch
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import (
@@ -155,7 +155,7 @@
{% endif -%}
# End.
-# Below: " if is_tf_available():" if generating TensorFlow
+# Below: " # TensorFlow model imports" if generating TensorFlow
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import (
@@ -179,7 +179,7 @@
{% endif -%}
# End.
-# Below: " if is_flax_available():" if generating Flax
+# Below: " # Flax model imports" if generating Flax
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import (
@@ -204,7 +204,7 @@
{% endif -%}
# End.
-# Below: " if is_tokenizers_available():"
+# Below: " # Fast tokenizers imports"
# Replace with:
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
# End.
diff --git a/test_results.txt b/test_results.txt
new file mode 100644
index 00000000000000..b061a2f1b9857d
--- /dev/null
+++ b/test_results.txt
@@ -0,0 +1,9 @@
+background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
+
+
+background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
+
+
+background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
+
+
diff --git a/tests/bart/test_modeling_bart.py b/tests/bart/test_modeling_bart.py
deleted file mode 100644
index db443164c58e5b..00000000000000
--- a/tests/bart/test_modeling_bart.py
+++ /dev/null
@@ -1,968 +0,0 @@
-# coding=utf-8
-# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Testing suite for the PyTorch BART model. """
-
-
-import copy
-import tempfile
-import unittest
-
-import timeout_decorator # noqa
-
-from transformers import BartConfig, is_torch_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from transformers.utils import cached_property
-
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
-
-
-if is_torch_available():
- import torch
-
- from transformers import (
- AutoModelForSequenceClassification,
- BartForCausalLM,
- BartForConditionalGeneration,
- BartForQuestionAnswering,
- BartForSequenceClassification,
- BartModel,
- BartTokenizer,
- pipeline,
- )
- from transformers.models.bart.modeling_bart import BartDecoder, BartEncoder, shift_tokens_right
-
-
-def prepare_bart_inputs_dict(
- config,
- input_ids,
- decoder_input_ids=None,
- attention_mask=None,
- decoder_attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- cross_attn_head_mask=None,
-):
- if attention_mask is None:
- attention_mask = input_ids.ne(config.pad_token_id)
- if decoder_attention_mask is None:
- decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
- if head_mask is None:
- head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
- if decoder_head_mask is None:
- decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
- if cross_attn_head_mask is None:
- cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
- return {
- "input_ids": input_ids,
- "decoder_input_ids": decoder_input_ids,
- "attention_mask": attention_mask,
- "decoder_attention_mask": attention_mask,
- "head_mask": head_mask,
- "decoder_head_mask": decoder_head_mask,
- "cross_attn_head_mask": cross_attn_head_mask,
- }
-
-
-class BartModelTester:
- def __init__(
- self,
- parent,
- batch_size=13,
- seq_length=7,
- is_training=True,
- use_labels=False,
- vocab_size=99,
- hidden_size=16,
- num_hidden_layers=2,
- num_attention_heads=4,
- intermediate_size=4,
- hidden_act="gelu",
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=20,
- eos_token_id=2,
- pad_token_id=1,
- bos_token_id=0,
- ):
- self.parent = parent
- self.batch_size = batch_size
- self.seq_length = seq_length
- self.is_training = is_training
- self.use_labels = use_labels
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.intermediate_size = intermediate_size
- self.hidden_act = hidden_act
- self.hidden_dropout_prob = hidden_dropout_prob
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
- self.max_position_embeddings = max_position_embeddings
- self.eos_token_id = eos_token_id
- self.pad_token_id = pad_token_id
- self.bos_token_id = bos_token_id
-
- def prepare_config_and_inputs(self):
- input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
- input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
- 3,
- )
- input_ids[:, -1] = self.eos_token_id # Eos Token
-
- decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
- config = self.get_config()
- inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
- return config, inputs_dict
-
- def get_config(self):
- return BartConfig(
- vocab_size=self.vocab_size,
- d_model=self.hidden_size,
- encoder_layers=self.num_hidden_layers,
- decoder_layers=self.num_hidden_layers,
- encoder_attention_heads=self.num_attention_heads,
- decoder_attention_heads=self.num_attention_heads,
- encoder_ffn_dim=self.intermediate_size,
- decoder_ffn_dim=self.intermediate_size,
- dropout=self.hidden_dropout_prob,
- attention_dropout=self.attention_probs_dropout_prob,
- max_position_embeddings=self.max_position_embeddings,
- eos_token_id=self.eos_token_id,
- bos_token_id=self.bos_token_id,
- pad_token_id=self.pad_token_id,
- )
-
- def get_pipeline_config(self):
- config = self.get_config()
- config.max_position_embeddings = 100
- return config
-
- def prepare_config_and_inputs_for_common(self):
- config, inputs_dict = self.prepare_config_and_inputs()
- return config, inputs_dict
-
- def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
- model = BartModel(config=config).get_decoder().to(torch_device).eval()
- input_ids = inputs_dict["input_ids"]
- attention_mask = inputs_dict["attention_mask"]
- head_mask = inputs_dict["head_mask"]
-
- # first forward pass
- outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
-
- output, past_key_values = outputs.to_tuple()
-
- # create hypothetical multiple next token and extent to next_input_ids
- next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
- next_attn_mask = ids_tensor((self.batch_size, 3), 2)
-
- # append to next input_ids and
- next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
- next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
-
- output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
- output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
- "last_hidden_state"
- ]
-
- # select random slice
- random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
- output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
- output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
-
- self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
-
- # test that outputs are equal for slice
- self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
-
- def check_encoder_decoder_model_standalone(self, config, inputs_dict):
- model = BartModel(config=config).to(torch_device).eval()
- outputs = model(**inputs_dict)
-
- encoder_last_hidden_state = outputs.encoder_last_hidden_state
- last_hidden_state = outputs.last_hidden_state
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- encoder = model.get_encoder()
- encoder.save_pretrained(tmpdirname)
- encoder = BartEncoder.from_pretrained(tmpdirname).to(torch_device)
-
- encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
- 0
- ]
-
- self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- decoder = model.get_decoder()
- decoder.save_pretrained(tmpdirname)
- decoder = BartDecoder.from_pretrained(tmpdirname).to(torch_device)
-
- last_hidden_state_2 = decoder(
- input_ids=inputs_dict["decoder_input_ids"],
- attention_mask=inputs_dict["decoder_attention_mask"],
- encoder_hidden_states=encoder_last_hidden_state,
- encoder_attention_mask=inputs_dict["attention_mask"],
- )[0]
-
- self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
-
-
-@require_torch
-class BartHeadTests(unittest.TestCase):
- vocab_size = 99
-
- def _get_config_and_data(self):
- input_ids = torch.tensor(
- [
- [71, 82, 18, 33, 46, 91, 2],
- [68, 34, 26, 58, 30, 82, 2],
- [5, 97, 17, 39, 94, 40, 2],
- [76, 83, 94, 25, 70, 78, 2],
- [87, 59, 41, 35, 48, 66, 2],
- [55, 13, 16, 58, 5, 2, 1], # note padding
- [64, 27, 31, 51, 12, 75, 2],
- [52, 64, 86, 17, 83, 39, 2],
- [48, 61, 9, 24, 71, 82, 2],
- [26, 1, 60, 48, 22, 13, 2],
- [21, 5, 62, 28, 14, 76, 2],
- [45, 98, 37, 86, 59, 48, 2],
- [70, 70, 50, 9, 28, 0, 2],
- ],
- dtype=torch.long,
- device=torch_device,
- )
-
- batch_size = input_ids.shape[0]
- config = BartConfig(
- vocab_size=self.vocab_size,
- d_model=24,
- encoder_layers=2,
- decoder_layers=2,
- encoder_attention_heads=2,
- decoder_attention_heads=2,
- encoder_ffn_dim=32,
- decoder_ffn_dim=32,
- max_position_embeddings=48,
- eos_token_id=2,
- pad_token_id=1,
- bos_token_id=0,
- )
- return config, input_ids, batch_size
-
- def test_sequence_classification_forward(self):
- config, input_ids, batch_size = self._get_config_and_data()
- labels = _long_tensor([2] * batch_size).to(torch_device)
- model = BartForSequenceClassification(config)
- model.to(torch_device)
- outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
- expected_shape = torch.Size((batch_size, config.num_labels))
- self.assertEqual(outputs["logits"].shape, expected_shape)
- self.assertIsInstance(outputs["loss"].item(), float)
-
- def test_question_answering_forward(self):
- config, input_ids, batch_size = self._get_config_and_data()
- sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
- model = BartForQuestionAnswering(config)
- model.to(torch_device)
- outputs = model(
- input_ids=input_ids,
- start_positions=sequence_labels,
- end_positions=sequence_labels,
- )
-
- self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
- self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
- self.assertIsInstance(outputs["loss"].item(), float)
-
- @timeout_decorator.timeout(1)
- def test_lm_forward(self):
- config, input_ids, batch_size = self._get_config_and_data()
- lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
- lm_model = BartForConditionalGeneration(config)
- lm_model.to(torch_device)
- outputs = lm_model(input_ids=input_ids, labels=lm_labels)
- expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
- self.assertEqual(outputs["logits"].shape, expected_shape)
- self.assertIsInstance(outputs["loss"].item(), float)
-
- def test_lm_uneven_forward(self):
- config = BartConfig(
- vocab_size=self.vocab_size,
- d_model=14,
- encoder_layers=2,
- decoder_layers=2,
- encoder_attention_heads=2,
- decoder_attention_heads=2,
- encoder_ffn_dim=8,
- decoder_ffn_dim=8,
- max_position_embeddings=48,
- )
- lm_model = BartForConditionalGeneration(config).to(torch_device)
- context = torch.tensor(
- [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], device=torch_device, dtype=torch.long
- )
- summary = torch.tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], device=torch_device, dtype=torch.long)
- outputs = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
- expected_shape = (*summary.shape, config.vocab_size)
- self.assertEqual(outputs["logits"].shape, expected_shape)
-
- def test_generate_beam_search(self):
- input_ids = torch.tensor([[71, 82, 2], [68, 34, 2]], device=torch_device, dtype=torch.long)
- config = BartConfig(
- vocab_size=self.vocab_size,
- d_model=24,
- encoder_layers=2,
- decoder_layers=2,
- encoder_attention_heads=2,
- decoder_attention_heads=2,
- encoder_ffn_dim=32,
- decoder_ffn_dim=32,
- max_position_embeddings=48,
- eos_token_id=2,
- pad_token_id=1,
- bos_token_id=0,
- )
- lm_model = BartForConditionalGeneration(config).to(torch_device)
- lm_model.eval()
-
- max_length = 5
- generated_ids = lm_model.generate(
- input_ids.clone(),
- do_sample=True,
- num_return_sequences=1,
- num_beams=2,
- no_repeat_ngram_size=3,
- max_length=max_length,
- )
- self.assertEqual(generated_ids.shape, (input_ids.shape[0], max_length))
-
- def test_shift_tokens_right(self):
- input_ids = torch.tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=torch.long)
- shifted = shift_tokens_right(input_ids, 1, 2)
- n_pad_before = input_ids.eq(1).float().sum()
- n_pad_after = shifted.eq(1).float().sum()
- self.assertEqual(shifted.shape, input_ids.shape)
- self.assertEqual(n_pad_after, n_pad_before - 1)
- self.assertTrue(torch.eq(shifted[:, 0], 2).all())
-
- @slow
- def test_tokenization(self):
- tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
- examples = [" Hello world", " DomDramg"] # need leading spaces for equality
- fairseq_results = [
- torch.tensor([0, 20920, 232, 2]),
- torch.tensor([0, 11349, 495, 4040, 571, 2]),
- ]
- for ex, desired_result in zip(examples, fairseq_results):
- bart_toks = tokenizer.encode(ex, return_tensors="pt").squeeze()
- assert_tensors_close(desired_result.long(), bart_toks, prefix=ex)
-
- def test_generate_fp16(self):
- config, input_ids, batch_size = self._get_config_and_data()
- attention_mask = input_ids.ne(1).to(torch_device)
- model = BartForConditionalGeneration(config).eval().to(torch_device)
- if torch_device == "cuda":
- model.half()
- model.generate(input_ids, attention_mask=attention_mask)
- model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
-
- def test_dummy_inputs(self):
- config, *_ = self._get_config_and_data()
- model = BartForConditionalGeneration(config).eval().to(torch_device)
- model(**model.dummy_inputs)
-
- def test_resize_tokens_embeddings_more(self):
- config, input_ids, _ = self._get_config_and_data()
-
- def _get_embs(m):
- return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())
-
- model = BartForConditionalGeneration(config).eval().to(torch_device)
- input, output = _get_embs(model)
- self.assertTrue(torch.eq(input, output).all())
- new_vocab_size = 45
- model.resize_token_embeddings(new_vocab_size)
- input_new, output_new = _get_embs(model)
- self.assertEqual(input_new.shape, (new_vocab_size, config.d_model))
- self.assertEqual(output_new.shape, (new_vocab_size, config.d_model))
- self.assertTrue(torch.eq(input_new, output_new).all())
-
-
-@require_torch
-class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
- all_model_classes = (
- (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
- if is_torch_available()
- else ()
- )
- all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
- is_encoder_decoder = True
- test_pruning = False
- test_missing_keys = False
-
- def setUp(self):
- self.model_tester = BartModelTester(self)
- self.config_tester = ConfigTester(self, config_class=BartConfig)
-
- def test_config(self):
- self.config_tester.run_common_tests()
-
- def test_save_load_strict(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs()
- for model_class in self.all_model_classes:
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
- self.assertEqual(info["missing_keys"], [])
-
- def test_decoder_model_past_with_large_inputs(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
-
- def test_encoder_decoder_model_standalone(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
- self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
-
- # BartForSequenceClassification does not support inputs_embeds
- def test_inputs_embeds(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in (BartModel, BartForConditionalGeneration, BartForQuestionAnswering):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
-
- if not self.is_encoder_decoder:
- input_ids = inputs["input_ids"]
- del inputs["input_ids"]
- else:
- encoder_input_ids = inputs["input_ids"]
- decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
- del inputs["input_ids"]
- inputs.pop("decoder_input_ids", None)
-
- wte = model.get_input_embeddings()
- if not self.is_encoder_decoder:
- inputs["inputs_embeds"] = wte(input_ids)
- else:
- inputs["inputs_embeds"] = wte(encoder_input_ids)
- inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
-
- with torch.no_grad():
- model(**inputs)[0]
-
- def test_generate_fp16(self):
- config, input_dict = self.model_tester.prepare_config_and_inputs()
- input_ids = input_dict["input_ids"]
- attention_mask = input_ids.ne(1).to(torch_device)
- model = BartForConditionalGeneration(config).eval().to(torch_device)
- if torch_device == "cuda":
- model.half()
- model.generate(input_ids, attention_mask=attention_mask)
- model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
-
-
-def assert_tensors_close(a, b, atol=1e-12, prefix=""):
- """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
- if a is None and b is None:
- return True
- try:
- if torch.allclose(a, b, atol=atol):
- return True
- raise
- except Exception:
- pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
- if a.numel() > 100:
- msg = f"tensor values are {pct_different:.1%} percent different."
- else:
- msg = f"{a} != {b}"
- if prefix:
- msg = prefix + ": " + msg
- raise AssertionError(msg)
-
-
-def _long_tensor(tok_lst):
- return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
-
-
-@require_torch
-@slow
-class FastIntegrationTests(unittest.TestCase):
- """These tests are useful for debugging since they operate on a model with 1 encoder layer and 1 decoder layer."""
-
- @cached_property
- def tok(self):
- return BartTokenizer.from_pretrained("facebook/bart-large")
-
- @cached_property
- def xsum_1_1_model(self):
- return BartForConditionalGeneration.from_pretrained("sshleifer/distilbart-xsum-1-1")
-
- def test_xsum_1_1_generation(self):
- hf = self.xsum_1_1_model
- tok = self.tok
- ARTICLE = 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.'
- EXPECTED = " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
-
- dct = tok(ARTICLE, return_tensors="pt")
- generated_ids = hf.generate(**dct, num_beams=4)
- result = tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
- assert EXPECTED == result
-
- def test_xsum_1_1_batch_generation(self):
- # test batch
-
- batch = self.tok(
- [
- 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.',
- 'The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.',
- ],
- return_tensors="pt",
- padding="longest",
- truncation=True,
- )
- generated_ids = self.xsum_1_1_model.generate(**batch, num_beams=4)
- result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
- assert (
- result[0]
- == " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
- )
- assert (
- result[1]
- == " An investigation into the crash that killed at least 10 people in the French capital has been released by the French police investigating the crash."
- )
-
- def test_encoder_equiv(self):
- # test batch
-
- batch = self.tok(
- [
- 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.',
- 'The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.',
- ],
- return_tensors="pt",
- padding="longest",
- truncation=True,
- )
- features = self.xsum_1_1_model.get_encoder()(**batch).last_hidden_state
- expected = [[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]]
- assert_tensors_close(features[0, :3, :3], torch.tensor(expected), atol=1e-3)
-
-
-@require_torch
-@require_sentencepiece
-@require_tokenizers
-class BartModelIntegrationTests(unittest.TestCase):
- @cached_property
- def default_tokenizer(self):
- return BartTokenizer.from_pretrained("facebook/bart-large")
-
- @slow
- def test_inference_no_head(self):
- model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
- input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
- attention_mask = input_ids.ne(model.config.pad_token_id)
- with torch.no_grad():
- output = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
- expected_shape = torch.Size((1, 11, 1024))
- self.assertEqual(output.shape, expected_shape)
- expected_slice = torch.tensor(
- [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device
- )
- self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3))
-
- @slow
- def test_base_mask_filling(self):
- pbase = pipeline(task="fill-mask", model="facebook/bart-base")
- src_text = [" I went to the ."]
- results = [x["token_str"] for x in pbase(src_text)]
- assert " bathroom" in results
-
- @slow
- def test_large_mask_filling(self):
- plarge = pipeline(task="fill-mask", model="facebook/bart-large")
- src_text = [" I went to the ."]
- results = [x["token_str"] for x in plarge(src_text)]
- expected_results = [" bathroom", " gym", " wrong", " movies", " hospital"]
- self.assertListEqual(results, expected_results)
-
- @slow
- def test_mnli_inference(self):
- example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1]
- input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b])
-
- model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli").to(
- torch_device
- ) # eval called in from_pre
- attention_mask = input_ids.ne(model.config.pad_token_id)
- # Test that model hasn't changed
- with torch.no_grad():
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-
- batched_logits = outputs.logits
- expected_shape = torch.Size((2, 3))
- self.assertEqual(batched_logits.shape, expected_shape)
- expected_slice = torch.tensor([[0.1907, 1.4342, -1.0289]], device=torch_device)
- logits_arr = batched_logits[0].detach()
-
- # Test that padding does not change results
- input_ids_no_pad = _long_tensor([example_b[:-1]])
- attention_mask_no_pad = input_ids_no_pad.ne(model.config.pad_token_id)
-
- with torch.no_grad():
- logits2 = model(input_ids=input_ids_no_pad, attention_mask=attention_mask_no_pad).logits.squeeze()
- assert_tensors_close(batched_logits[1], logits2, atol=1e-3)
- assert_tensors_close(expected_slice, logits_arr, atol=1e-3)
-
- @slow
- def test_xsum_summarization_same_as_fairseq(self):
- model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device)
- tok = self.default_tokenizer
-
- PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
-
- EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
- dct = tok.batch_encode_plus(
- [PGE_ARTICLE],
- max_length=1024,
- padding="max_length",
- truncation=True,
- return_tensors="pt",
- ).to(torch_device)
-
- hypotheses_batch = model.generate(
- input_ids=dct["input_ids"],
- attention_mask=dct["attention_mask"],
- num_beams=2,
- max_length=62,
- min_length=11,
- length_penalty=1.0,
- no_repeat_ngram_size=3,
- early_stopping=True,
- decoder_start_token_id=model.config.eos_token_id,
- )
-
- decoded = tok.batch_decode(
- hypotheses_batch,
- skip_special_tokens=True,
- )
- self.assertEqual(EXPECTED_SUMMARY, decoded[0])
-
- def test_xsum_config_generation_params(self):
- config = BartConfig.from_pretrained("facebook/bart-large-xsum")
- expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
- config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
- self.assertDictEqual(expected_params, config_params)
-
- @slow
- def test_cnn_summarization_same_as_fairseq(self):
- hf = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
- tok = BartTokenizer.from_pretrained("facebook/bart-large")
-
- FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noq
-
- SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
-
- # The below article tests that we don't add any hypotheses outside of the top n_beams
- IRAN_ARTICLE = " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
-
- ARTICLE_SUBWAY = ' New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
-
- dct = tok.batch_encode_plus(
- [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
- max_length=1024,
- padding="max_length",
- truncation_strategy="only_first",
- truncation=True,
- return_tensors="pt",
- )
-
- self.assertEqual(1024, dct["input_ids"].shape[1])
- hypotheses_batch = hf.generate(
- input_ids=dct["input_ids"].to(torch_device),
- attention_mask=dct["attention_mask"].to(torch_device),
- num_beams=2,
- )
- assert hypotheses_batch[:, 1].eq(0).all().item()
-
- EXPECTED = [
- "A French prosecutor says he is not aware of any video footage from on board the plane. Two German "
- "magazines claim to have found a cell phone video showing the crash. The publications say they watched "
- "the video, which was found by a source close to the investigation. All 150 on board Germanwings Flight "
- "9525 were killed.",
- "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court "
- "jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the "
- "Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a "
- "move toward greater justice.",
- "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The "
- "debate that has already begun will likely result in more heat than light. He says critics have made "
- "dubious assumptions and doubtful assertions. Bergen says the goal was to block Iran from building a "
- "nuclear weapon.",
- "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors "
- "say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the "
- "Bronx on Friday. If convicted, she faces up to four years in prison.",
- ]
-
- generated_summaries = tok.batch_decode(
- hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
- )
- assert generated_summaries == EXPECTED
-
-
-class BartStandaloneDecoderModelTester:
- def __init__(
- self,
- parent,
- vocab_size=99,
- batch_size=13,
- d_model=16,
- decoder_seq_length=7,
- is_training=True,
- is_decoder=True,
- use_attention_mask=True,
- use_cache=False,
- use_labels=True,
- decoder_start_token_id=2,
- decoder_ffn_dim=32,
- decoder_layers=4,
- encoder_attention_heads=4,
- decoder_attention_heads=4,
- max_position_embeddings=30,
- is_encoder_decoder=False,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
- scope=None,
- ):
- self.parent = parent
- self.batch_size = batch_size
- self.decoder_seq_length = decoder_seq_length
- # For common tests
- self.seq_length = self.decoder_seq_length
- self.is_training = is_training
- self.use_attention_mask = use_attention_mask
- self.use_labels = use_labels
-
- self.vocab_size = vocab_size
- self.d_model = d_model
- self.hidden_size = d_model
- self.num_hidden_layers = decoder_layers
- self.decoder_layers = decoder_layers
- self.decoder_ffn_dim = decoder_ffn_dim
- self.encoder_attention_heads = encoder_attention_heads
- self.decoder_attention_heads = decoder_attention_heads
- self.num_attention_heads = decoder_attention_heads
- self.eos_token_id = eos_token_id
- self.bos_token_id = bos_token_id
- self.pad_token_id = pad_token_id
- self.decoder_start_token_id = decoder_start_token_id
- self.use_cache = use_cache
- self.max_position_embeddings = max_position_embeddings
- self.is_encoder_decoder = is_encoder_decoder
-
- self.scope = None
- self.decoder_key_length = decoder_seq_length
- self.base_model_out_len = 2
- self.decoder_attention_idx = 1
-
- def prepare_config_and_inputs(self):
- input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
-
- attention_mask = None
- if self.use_attention_mask:
- attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
-
- lm_labels = None
- if self.use_labels:
- lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
-
- config = BartConfig(
- vocab_size=self.vocab_size,
- d_model=self.d_model,
- encoder_layers=self.decoder_layers,
- decoder_layers=self.decoder_layers,
- decoder_ffn_dim=self.decoder_ffn_dim,
- encoder_attention_heads=self.encoder_attention_heads,
- decoder_attention_heads=self.decoder_attention_heads,
- eos_token_id=self.eos_token_id,
- bos_token_id=self.bos_token_id,
- use_cache=self.use_cache,
- pad_token_id=self.pad_token_id,
- decoder_start_token_id=self.decoder_start_token_id,
- max_position_embeddings=self.max_position_embeddings,
- is_encoder_decoder=self.is_encoder_decoder,
- )
-
- return (
- config,
- input_ids,
- attention_mask,
- lm_labels,
- )
-
- def prepare_config_and_inputs_for_decoder(self):
- (
- config,
- input_ids,
- attention_mask,
- lm_labels,
- ) = self.prepare_config_and_inputs()
-
- encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size])
- encoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
-
- return (
- config,
- input_ids,
- attention_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- lm_labels,
- )
-
- def create_and_check_decoder_model_past(
- self,
- config,
- input_ids,
- attention_mask,
- lm_labels,
- ):
- config.use_cache = True
- model = BartDecoder(config=config).to(torch_device).eval()
- # first forward pass
- outputs = model(input_ids, use_cache=True)
- outputs_use_cache_conf = model(input_ids)
- outputs_no_past = model(input_ids, use_cache=False)
-
- self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
- self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
-
- past_key_values = outputs["past_key_values"]
-
- # create hypothetical next token and extent to next_input_ids
- next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
-
- # append to next input_ids and
- next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
-
- output_from_no_past = model(next_input_ids)["last_hidden_state"]
- output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
-
- # select random slice
- random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
- output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
- output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
-
- # test that outputs are equal for slice
- assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
-
- def create_and_check_decoder_model_attention_mask_past(
- self,
- config,
- input_ids,
- attention_mask,
- lm_labels,
- ):
- model = BartDecoder(config=config).to(torch_device).eval()
-
- # create attention mask
- attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
-
- half_seq_length = input_ids.shape[-1] // 2
- attn_mask[:, half_seq_length:] = 0
-
- # first forward pass
- past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
-
- # create hypothetical next token and extent to next_input_ids
- next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
-
- # change a random masked slice from input_ids
- random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
- random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
- input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
-
- # append to next input_ids and attn_mask
- next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
- attn_mask = torch.cat(
- [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
- dim=1,
- )
-
- # get two different outputs
- output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
- output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
- "last_hidden_state"
- ]
-
- # select random slice
- random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
- output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
- output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
-
- # test that outputs are equal for slice
- assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
-
- def prepare_config_and_inputs_for_common(self):
- config_and_inputs = self.prepare_config_and_inputs()
- (
- config,
- input_ids,
- attention_mask,
- lm_labels,
- ) = config_and_inputs
-
- inputs_dict = {
- "input_ids": input_ids,
- "attention_mask": attention_mask,
- }
- return config, inputs_dict
-
-
-@require_torch
-class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
- all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
- all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
- test_pruning = False
- is_encoder_decoder = False
-
- def setUp(
- self,
- ):
- self.model_tester = BartStandaloneDecoderModelTester(self, is_training=False)
- self.config_tester = ConfigTester(self, config_class=BartConfig)
-
- def test_config(self):
- self.config_tester.run_common_tests()
-
- def test_decoder_model_past(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
-
- def test_decoder_model_attn_mask_past(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
-
- def test_retain_grad_hidden_states_attentions(self):
- # decoder cannot keep gradients
- return
diff --git a/tests/bart/test_modeling_tf_bart.py b/tests/bart/test_modeling_tf_bart.py
deleted file mode 100644
index 070d16e56b1461..00000000000000
--- a/tests/bart/test_modeling_tf_bart.py
+++ /dev/null
@@ -1,477 +0,0 @@
-# coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-
-from transformers import BartConfig, BartTokenizer, is_tf_available
-from transformers.testing_utils import require_tf, slow
-from transformers.utils import cached_property
-
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
-from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin
-
-
-if is_tf_available():
- import tensorflow as tf
-
- from transformers import TFBartForConditionalGeneration, TFBartModel
-
-
-@require_tf
-class TFBartModelTester:
- config_cls = BartConfig
- config_updates = {}
- hidden_act = "gelu"
-
- def __init__(
- self,
- parent,
- batch_size=13,
- seq_length=7,
- is_training=True,
- use_labels=False,
- vocab_size=99,
- hidden_size=32,
- num_hidden_layers=5,
- num_attention_heads=4,
- intermediate_size=37,
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=20,
- eos_token_id=2,
- pad_token_id=1,
- bos_token_id=0,
- ):
- self.parent = parent
- self.batch_size = batch_size
- self.seq_length = seq_length
- self.is_training = is_training
- self.use_labels = use_labels
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.intermediate_size = intermediate_size
-
- self.hidden_dropout_prob = hidden_dropout_prob
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
- self.max_position_embeddings = max_position_embeddings
- self.eos_token_id = eos_token_id
- self.pad_token_id = pad_token_id
- self.bos_token_id = bos_token_id
-
- def prepare_config_and_inputs_for_common(self):
- input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
- eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
- input_ids = tf.concat([input_ids, eos_tensor], axis=1)
-
- decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
- config = self.config_cls(
- vocab_size=self.vocab_size,
- d_model=self.hidden_size,
- encoder_layers=self.num_hidden_layers,
- decoder_layers=self.num_hidden_layers,
- encoder_attention_heads=self.num_attention_heads,
- decoder_attention_heads=self.num_attention_heads,
- encoder_ffn_dim=self.intermediate_size,
- decoder_ffn_dim=self.intermediate_size,
- dropout=self.hidden_dropout_prob,
- attention_dropout=self.attention_probs_dropout_prob,
- max_position_embeddings=self.max_position_embeddings,
- eos_token_ids=[2],
- bos_token_id=self.bos_token_id,
- pad_token_id=self.pad_token_id,
- decoder_start_token_id=self.pad_token_id,
- **self.config_updates,
- )
- inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
- return config, inputs_dict
-
- def check_decoder_model_past_large_inputs(self, config, inputs_dict):
- model = TFBartModel(config=config).get_decoder()
- input_ids = inputs_dict["input_ids"]
-
- input_ids = input_ids[:1, :]
- attention_mask = inputs_dict["attention_mask"][:1, :]
- head_mask = inputs_dict["head_mask"]
- self.batch_size = 1
-
- # first forward pass
- outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
-
- output, past_key_values = outputs.to_tuple()
-
- # create hypothetical next token and extent to next_input_ids
- next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
- next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)
-
- # append to next input_ids and
- next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
- next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
-
- output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
- output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
-
- self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
-
- # select random slice
- random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
- output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
- output_from_past_slice = output_from_past[:, :, random_slice_idx]
-
- # test that outputs are equal for slice
- tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
-
-
-def prepare_bart_inputs_dict(
- config,
- input_ids,
- decoder_input_ids,
- attention_mask=None,
- decoder_attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- cross_attn_head_mask=None,
-):
- if attention_mask is None:
- attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
- if decoder_attention_mask is None:
- decoder_attention_mask = tf.concat(
- [
- tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
- tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
- ],
- axis=-1,
- )
- if head_mask is None:
- head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
- if decoder_head_mask is None:
- decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
- if cross_attn_head_mask is None:
- cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
- return {
- "input_ids": input_ids,
- "decoder_input_ids": decoder_input_ids,
- "attention_mask": attention_mask,
- "decoder_attention_mask": decoder_attention_mask,
- "head_mask": head_mask,
- "decoder_head_mask": decoder_head_mask,
- "cross_attn_head_mask": cross_attn_head_mask,
- }
-
-
-@require_tf
-class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
- all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
- all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
- is_encoder_decoder = True
- test_pruning = False
- test_onnx = True
- onnx_min_opset = 10
-
- def setUp(self):
- self.model_tester = TFBartModelTester(self)
- self.config_tester = ConfigTester(self, config_class=BartConfig)
-
- def test_config(self):
- self.config_tester.run_common_tests()
-
- def test_decoder_model_past_large_inputs(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
- self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
-
- def test_model_common_attributes(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- model = model_class(config)
- assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
-
- if model_class in self.all_generative_model_classes:
- x = model.get_output_embeddings()
- assert isinstance(x, tf.keras.layers.Layer)
- name = model.get_bias()
- assert isinstance(name, dict)
- for k, v in name.items():
- assert isinstance(v, tf.Variable)
- else:
- x = model.get_output_embeddings()
- assert x is None
- name = model.get_bias()
- assert name is None
-
- def test_resize_token_embeddings(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- def _get_word_embedding_weight(model, embedding_layer):
- if hasattr(embedding_layer, "weight"):
- return embedding_layer.weight
- else:
- # Here we build the word embeddings weights if not exists.
- # And then we retry to get the attribute once built.
- model(model.dummy_inputs)
- if hasattr(embedding_layer, "weight"):
- return embedding_layer.weight
- else:
- return None
-
- for model_class in self.all_model_classes:
- for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
- # build the embeddings
- model = model_class(config=config)
- old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
- old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
- old_final_logits_bias = model.get_bias()
-
- # reshape the embeddings
- model.resize_token_embeddings(size)
- new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
- new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
- new_final_logits_bias = model.get_bias()
-
- # check that the resized embeddings size matches the desired size.
- assert_size = size if size is not None else config.vocab_size
-
- self.assertEqual(new_input_embeddings.shape[0], assert_size)
-
- # check that weights remain the same after resizing
- models_equal = True
- for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
- if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
- models_equal = False
- self.assertTrue(models_equal)
-
- if old_output_embeddings is not None and new_output_embeddings is not None:
- self.assertEqual(new_output_embeddings.shape[0], assert_size)
-
- models_equal = True
- for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
- if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
- models_equal = False
- self.assertTrue(models_equal)
-
- if old_final_logits_bias is not None and new_final_logits_bias is not None:
- old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
- new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
- self.assertEqual(new_final_logits_bias.shape[0], 1)
- self.assertEqual(new_final_logits_bias.shape[1], assert_size)
-
- models_equal = True
- for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
- for p1, p2 in zip(old, new):
- if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
- models_equal = False
- self.assertTrue(models_equal)
-
- def test_saved_model_creation(self):
- # This test is too long (>30sec) and makes fail the CI
- pass
-
-
-def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
- """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
- if a is None and b is None:
- return True
- try:
- if tf.debugging.assert_near(a, b, atol=atol):
- return True
- raise
- except Exception:
- if len(prefix) > 0:
- prefix = f"{prefix}: "
- raise AssertionError(f"{prefix}{a} != {b}")
-
-
-def _long_tensor(tok_lst):
- return tf.constant(tok_lst, dtype=tf.int32)
-
-
-@require_tf
-class TFBartHeadTests(unittest.TestCase):
- vocab_size = 99
-
- def _get_config_and_data(self):
- eos_column_vector = tf.ones((4, 1), dtype=tf.int32) * 2
- input_ids = tf.concat([ids_tensor((4, 6), self.vocab_size - 3) + 3, eos_column_vector], axis=1)
- batch_size = input_ids.shape[0]
- config = BartConfig(
- vocab_size=self.vocab_size,
- d_model=24,
- encoder_layers=2,
- decoder_layers=2,
- encoder_attention_heads=2,
- decoder_attention_heads=2,
- encoder_ffn_dim=32,
- decoder_ffn_dim=32,
- max_position_embeddings=48,
- eos_token_id=2,
- pad_token_id=1,
- bos_token_id=0,
- decoder_start_token_id=2,
- )
- return config, input_ids, batch_size
-
- def test_lm_forward(self):
- config, input_ids, batch_size = self._get_config_and_data()
- decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
- lm_model = TFBartForConditionalGeneration(config)
- outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
- expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
- self.assertEqual(outputs.logits.shape, expected_shape)
-
- def test_lm_uneven_forward(self):
- config = BartConfig(
- vocab_size=10,
- d_model=24,
- encoder_layers=2,
- decoder_layers=2,
- encoder_attention_heads=2,
- decoder_attention_heads=2,
- encoder_ffn_dim=32,
- decoder_ffn_dim=32,
- max_position_embeddings=48,
- )
- lm_model = TFBartForConditionalGeneration(config)
- context = tf.fill((7, 2), 4)
- summary = tf.fill((7, 7), 6)
- outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
- expected_shape = (*summary.shape, config.vocab_size)
- self.assertEqual(outputs.logits.shape, expected_shape)
-
-
-@slow
-@require_tf
-class TFBartModelIntegrationTest(unittest.TestCase):
- def test_inference_no_head(self):
- model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large").model
-
- input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
- attention_mask = tf.cast(tf.math.not_equal(input_ids, model.config.pad_token_id), tf.int8)
- output = model(input_ids=input_ids, attention_mask=attention_mask)[0]
- expected_shape = (1, 11, 1024)
- self.assertEqual(output.shape, expected_shape)
- expected_slice = tf.convert_to_tensor(
- [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
- )
- tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-3)
-
- def test_cnn_summarization_same_as_fairseq_hard(self):
- hf = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
- tok = self.tok
-
- FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
- EXPECTED_SUMMARY_FRANCE = 'French prosecutor says he\'s not aware of any video footage from on board the plane. German daily Bild and French Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa confirms co-pilot Andreas Lubitz had battled depression.'
-
- SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
- EXPECTED_SUMMARY_SHORTER = "The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move toward greater justice."
-
- # The below article tests that we don't add any hypotheses outside of the top n_beams
- IRAN_ARTICLE = " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
- EXPECTED_SUMMARY_IRAN = "The U.S. and its negotiating partners reached a very strong framework agreement with Iran. Peter Bergen: The debate that has already begun will likely result in more heat than light. He says the agreement limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Bergen says the most important aim of a nuclear deal is preventing a nuclear Iran."
-
- ARTICLE_SUBWAY = ' New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
- EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
-
- dct = tok(
- [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
- max_length=1024,
- truncation_strategy="only_first",
- padding="longest",
- truncation=True,
- return_tensors="tf",
- )
- self.assertEqual(1024, dct["input_ids"].shape[1])
- hypotheses_batch = hf.generate(
- input_ids=dct["input_ids"],
- attention_mask=dct["attention_mask"],
- )
-
- assert hypotheses_batch[:, 1].numpy().tolist() == [0, 0, 0, 0] # test force_bos_token_to_be_generated
- decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
- expected_batch = [
- EXPECTED_SUMMARY_FRANCE,
- EXPECTED_SUMMARY_SHORTER,
- EXPECTED_SUMMARY_IRAN,
- EXPECTED_SUMMARY_SUBWAY,
- ]
- assert decoded == expected_batch
-
- @cached_property
- def tok(self):
- return BartTokenizer.from_pretrained("facebook/bart-large")
-
-
-@slow
-@require_tf
-class FasterTFBartModelIntegrationTests(unittest.TestCase):
- """These tests are useful for debugging since they operate on a model with 1 encoder layer and 1 decoder layer."""
-
- @cached_property
- def tok(self):
- return BartTokenizer.from_pretrained("facebook/bart-large")
-
- @cached_property
- def xsum_1_1_model(self):
- return TFBartForConditionalGeneration.from_pretrained("sshleifer/distilbart-xsum-1-1")
-
- def test_xsum_1_1_generation(self):
- model = self.xsum_1_1_model
- assert model.model.decoder.embed_tokens._layer == model.model.shared
- ARTICLE = 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.'
- EXPECTED = " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
- dct = self.tok(ARTICLE, return_tensors="tf")
- generated_ids = model.generate(**dct, num_beams=4)
- result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
- assert result == EXPECTED
-
- def test_xsum_1_1_batch_generation(self):
- batch = self.tok(
- [
- 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.',
- 'The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.',
- ],
- return_tensors="tf",
- padding="longest",
- truncation=True,
- )
- generated_ids = self.xsum_1_1_model.generate(**batch, num_beams=4)
- result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
- assert (
- result[0]
- == " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court."
- )
- assert (
- result[1]
- == " An investigation into the crash that killed at least 10 people in the French capital has been released by the French police investigating the crash."
- )
-
- def test_encoder_equiv(self):
- batch = self.tok(
- [
- 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.',
- 'The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.',
- ],
- return_tensors="tf",
- padding="longest",
- truncation=True,
- )
- features = self.xsum_1_1_model.get_encoder()(**batch).last_hidden_state
-
- expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]])
- assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3)
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 9fba62815b0134..65ef9416cb7116 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -20,10 +20,12 @@
import unittest
from copy import deepcopy
+import datasets
+
from parameterized import parameterized
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
-from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
+from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
from transformers.testing_utils import (
CaptureLogger,
CaptureStd,
@@ -159,6 +161,12 @@ def setUp(self):
MASTER_ADDR="localhost", MASTER_PORT=master_port, RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
)
+ def tearDown(self):
+ super().tearDown()
+
+ # reset the ds config global so that tests state doesn't leak
+ unset_hf_deepspeed_config()
+
def test_init_zero3_fp16(self):
# test that zero.Init() works correctly under zero3/fp16
ds_config = {
@@ -195,28 +203,7 @@ def test_init_zero3_fp16(self):
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
-@require_deepspeed
-@require_torch_gpu
-class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
- """
-
- This class is for testing directly via get_regression_trainer
-
- It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
- which we can re-use here.
-
- Important: this class' setup can only work with a single gpu because it runs within the current
- pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
-
- Note: if any of the tests of this class get run there will be at least one gpu occupied by them
- until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
- won't be released until this pytest worker exits.
-
- This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
- processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
- is not a bug.
- """
-
+class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
def setUp(self):
super().setUp()
@@ -248,10 +235,39 @@ def setUp(self):
zero3=config_zero3,
)
+ def tearDown(self):
+ super().tearDown()
+
+ # reset the ds config global so that tests state doesn't leak
+ unset_hf_deepspeed_config()
+
def get_config_dict(self, stage):
# As some tests modify the dict, always make a copy
return deepcopy(self.ds_config_dict[stage])
+
+@require_deepspeed
+@require_torch_gpu
+class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, TrainerIntegrationCommon):
+ """
+
+ This class is for testing directly via get_regression_trainer
+
+ It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
+ which we can re-use here.
+
+ Important: this class' setup can only work with a single gpu because it runs within the current
+ pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
+
+ Note: if any of the tests of this class get run there will be at least one gpu occupied by them
+ until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
+ won't be released until this pytest worker exits.
+
+ This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
+ processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
+ is not a bug.
+ """
+
# --- These tests are enough to run on one of zero stages --- #
def test_hf_ds_config_mismatch(self):
@@ -522,7 +538,7 @@ def test_gradient_accumulation(self, stage, dtype):
# see the note above how to get identical loss on a small bs
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2)
- def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
+ def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
@@ -534,7 +550,8 @@ def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
else:
raise ValueError(f"unknown stage {stage}")
- ds_file_list.append("zero_pp_rank_0_mp_rank_00_optim_states.pt")
+ if dtype == "bf16":
+ ds_file_list.append("bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt")
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
@@ -578,7 +595,7 @@ def test_save_checkpoints(self, stage, dtype):
trainer.train()
total = int(self.n_epochs * 64 / self.batch_size)
- self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage)
+ self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage, dtype)
@parameterized.expand(params, name_func=parameterized_custom_name_func)
def test_can_resume_training_errors(self, stage, dtype):
@@ -724,6 +741,94 @@ def test_config_object(self):
self.assertFalse(is_deepspeed_zero3_enabled())
self.assertFalse(bool(config), "Deepspeed config should not be accessible")
+ @parameterized.expand(params, name_func=parameterized_custom_name_func)
+ def test_load_best_model(self, stage, dtype):
+ # Test that forced deepspeed reinit doesn't break the model. the forced re-init after
+ # loading the best model in Trainer is there to workaround this bug in Deepspeed
+ # https://github.com/microsoft/DeepSpeed/issues/1612
+ #
+ # The test is derived from a repro script submitted in this Issue:
+ # https://github.com/huggingface/transformers/issues/17114
+ #
+ # One additional feature of this test is that we use a non-AdamW optimizer to test that
+ # deepspeed doesn't fallback to AdamW, which would prevent the optimizer states from loading
+ # correctly
+
+ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer # noqa
+
+ output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False, before=False)
+
+ ds_config_dict = self.get_config_dict(stage)
+ del ds_config_dict["optimizer"] # will use HF Trainer optimizer
+ del ds_config_dict["scheduler"] # will use HF Trainer scheduler
+ # must use this setting to get the reload path exercised
+ ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
+
+ with mockenv_context(**self.dist_env_1_gpu):
+
+ args_dict = {
+ "per_gpu_train_batch_size": 1,
+ "per_gpu_eval_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "learning_rate": 1e-4,
+ "num_train_epochs": 1,
+ "do_train": True,
+ "do_eval": True,
+ "optim": "adafactor",
+ "evaluation_strategy": "steps",
+ "eval_steps": 1,
+ "save_strategy": "steps",
+ "save_steps": 1,
+ "load_best_model_at_end": True,
+ "max_steps": 1,
+ "deepspeed": ds_config_dict,
+ }
+
+ training_args = TrainingArguments(output_dir, **args_dict)
+ tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
+ model = T5ForConditionalGeneration.from_pretrained(T5_TINY)
+
+ def _add_eos_to_examples(example):
+ example["input_text"] = f"question: {example['question']} context: {example['context']}"
+ example["target_text"] = example["answers"]["text"][0] if len(example["answers"]["text"]) > 0 else ""
+ return example
+
+ def _convert_to_features(example_batch):
+ input_encodings = tokenizer.batch_encode_plus(
+ example_batch["input_text"], pad_to_max_length=True, max_length=512, truncation=True
+ )
+ target_encodings = tokenizer.batch_encode_plus(
+ example_batch["target_text"], pad_to_max_length=True, max_length=16, truncation=True
+ )
+
+ encodings = {
+ "input_ids": input_encodings["input_ids"],
+ "attention_mask": input_encodings["attention_mask"],
+ "labels": target_encodings["input_ids"],
+ }
+
+ return encodings
+
+ def get_dataset():
+ data_file = str(self.tests_dir / "fixtures/tests_samples/SQUAD/sample.json")
+ data_files = dict(train=data_file, validation=data_file)
+ raw_datasets = datasets.load_dataset("json", data_files=data_files, field="data")
+ train_dataset = raw_datasets["train"].map(_add_eos_to_examples).map(_convert_to_features, batched=True)
+ valid_dataset = deepcopy(train_dataset)
+ return train_dataset, valid_dataset
+
+ train_dataset, eval_dataset = get_dataset()
+
+ trainer = Trainer(
+ model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ )
+ trainer.train() # crash 1 was here
+ trainer.evaluate() # crash 2 was here
+
@slow
@require_deepspeed
@@ -1034,50 +1139,3 @@ def test_clm_from_config_zero3_fp16(self):
with CaptureStderr() as cs:
execute_subprocess_async(cmd, env=self.get_env())
self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
-
- @parameterized.expand(params, name_func=parameterized_custom_name_func)
- def test_load_best_model(self, stage, dtype):
- # this test exercises --load_best_model_at_end - the key is being able to resume after some training
-
- data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
- output_dir = self.get_auto_remove_tmp_dir()
- args = f"""
- --model_name_or_path {T5_TINY}
- --tokenizer_name {T5_TINY}
- --train_file {data_dir}/train.json
- --validation_file {data_dir}/val.json
- --output_dir {output_dir}
- --overwrite_output_dir
- --source_lang en
- --target_lang ro
- --do_train
- --max_train_samples 3
- --do_eval
- --max_eval_samples 1
- --logging_strategy steps
- --logging_steps 1
- --evaluation_strategy steps
- --eval_steps 1
- --save_strategy steps
- --save_steps 1
- --load_best_model_at_end
- --per_device_train_batch_size 1
- --per_device_eval_batch_size 1
- --num_train_epochs 1
- --report_to none
- """.split()
- args.extend(["--source_prefix", "translate English to Romanian: "])
-
- args.extend([f"--{dtype}"])
-
- ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
- script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
- launcher = get_launcher(distributed=False)
-
- cmd = launcher + script + args + ds_args
- # keep for quick debug
- # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
- with CaptureStd() as cs:
- execute_subprocess_async(cmd, env=self.get_env())
- # enough to test it didn't fail
- self.assertIn("DeepSpeed info", cs.out)
diff --git a/tests/deepspeed/test_model_zoo.py b/tests/deepspeed/test_model_zoo.py
index 905d40eadd5da3..f5b43f4c1bcca3 100644
--- a/tests/deepspeed/test_model_zoo.py
+++ b/tests/deepspeed/test_model_zoo.py
@@ -24,6 +24,7 @@
TestCasePlus,
execute_subprocess_async,
get_gpu_count,
+ get_tests_dir,
require_deepspeed,
require_torch_gpu,
slow,
@@ -41,51 +42,100 @@
set_seed(42)
+FIXTURE_DIRECTORY = get_tests_dir("fixtures")
+ROOT_DIRECTORY = os.path.join(dirname(get_tests_dir()))
+DS_TESTS_DIRECTORY = dirname(os.path.abspath(__file__))
+
# default torch.distributed port
DEFAULT_MASTER_PORT = "10999"
-# translation
-FSMT_TINY = "stas/tiny-wmt19-en-de"
-BART_TINY = "sshleifer/bart-tiny-random"
T5_SMALL = "t5-small"
-T5_TINY = "patrickvonplaten/t5-tiny-random"
-MBART_TINY = "sshleifer/tiny-mbart"
-MARIAN_TINY = "sshleifer/tiny-marian-en-de"
-
-# summarization
-PEGASUS_TINY = "stas/pegasus-cnn_dailymail-tiny-random"
-# causal lm
+# *** Working Models ***
+ALBERT_TINY = "hf-internal-testing/tiny-albert"
+BART_TINY = "sshleifer/bart-tiny-random"
+BERT_TINY = "hf-internal-testing/tiny-bert"
+BIGBIRD_PEGASUS_TINY = "hf-internal-testing/tiny-random-bigbird_pegasus"
+BIG_BIRD_TINY = "hf-internal-testing/tiny-random-big_bird"
+BLENDERBOT_TINY = "hf-internal-testing/tiny-random-blenderbot"
+BLOOM_TINY = "bigscience/bigscience-small-testing"
+DEBERTA_TINY = "hf-internal-testing/tiny-random-deberta"
+DEBERTA_V2_TINY = "hf-internal-testing/tiny-random-deberta-v2"
+DISTILBERT_TINY = "sshleifer/tiny-distilbert-base-cased"
+ELECTRA_TINY = "hf-internal-testing/tiny-electra"
+FLAUBERT_TINY = "hf-internal-testing/tiny-random-flaubert"
+FSMT_TINY = "stas/tiny-wmt19-en-de"
+FUNNEL_TINY = "hf-internal-testing/tiny-random-funnel"
GPT2_TINY = "sshleifer/tiny-gpt2"
+GPTJ_TINY = "hf-internal-testing/tiny-random-gptj"
+GPT_NEO_TINY = "hf-internal-testing/tiny-random-gpt_neo"
+LAYOUTLM_TINY = "hf-internal-testing/tiny-layoutlm"
+LED_TINY = "hf-internal-testing/tiny-random-led"
+LONGFORMER_TINY = "hf-internal-testing/tiny-random-longformer"
+M2M_100_TINY = "stas/tiny-m2m_100" # hf tiny model is unsuitable
+MARIAN_TINY = "sshleifer/tiny-marian-en-de"
+MBART_TINY = "sshleifer/tiny-mbart"
+MOBILEBERT_TINY = "hf-internal-testing/tiny-random-mobilebert"
+MPNET_TINY = "hf-internal-testing/tiny-random-mpnet"
+PEGASUS_TINY = "stas/pegasus-cnn_dailymail-tiny-random"
+PROPHETNET_TINY = "hf-internal-testing/tiny-random-prophetnet"
+ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
+SQUEEZEBERT_TINY = "hf-internal-testing/tiny-random-squeezebert"
+T5_TINY = "patrickvonplaten/t5-tiny-random"
+T5_V1_TINY = "hf-internal-testing/tiny-random-t5-v1.1"
+VIT_TINY = "hf-internal-testing/tiny-random-vit"
XLM_ROBERTA_TINY = "hf-internal-testing/tiny-xlm-roberta"
+XLNET_TINY = "sshleifer/tiny-xlnet-base-cased"
-# question-answering
-ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
-# masked lm
-DISTILBERT_TINY = "sshleifer/tiny-distilbert-base-cased"
-ELECTRA_TINY = "hf-internal-testing/tiny-electra"
+# *** To Fix ***
-# classification
-XLNET_TINY = "sshleifer/tiny-xlnet-base-cased"
-BERT_TINY = "hf-internal-testing/tiny-bert"
-FIXTURE_DIRECTORY = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures")
-ROOT_DIRECTORY = os.path.join(dirname(dirname(dirname(os.path.abspath(__file__)))))
+# *** tiny model issues ***
+# missing model files:
+MT5_TINY = "hf-internal-testing/tiny-random-mt5"
+CAMEMBERT_TINY = "hf-internal-testing/tiny-random-camembert"
+OPENAI_GPT_TINY = "hf-internal-testing/tiny-random-openai-gpt"
+
+# missing tokenizer files
+CONVBERT_TINY = "hf-internal-testing/tiny-random-convbert"
+LAYOUTLMV2_TINY = "hf-internal-testing/tiny-random-layoutlmv2"
+HUBERT_TINY = "hf-internal-testing/tiny-random-hubert"
+
+# issues with tokenizer
+CTRL_TINY = "hf-internal-testing/tiny-random-ctrl"
+TRANSFO_XL_TINY = "hf-internal-testing/tiny-random-transfo-xl" # same as ctrl
+
+# other issues with tiny models
+IBERT_TINY = "hf-internal-testing/tiny-random-ibert" # multiple issues with either mlm/qa/clas
+REFORMER_TINY = "hf-internal-testing/tiny-random-reformer" # multiple issues with either mlm/qa/clas
-# TODO: to add:
-# albert
-# deberta
-# funnel
-# longformer
-# dpr
-# gpt_neo
-# camembert
-# deberta-v2
-# m2m_100
-# tapas
-# vit
-# big_bird
+# *** Lacking official examples to test with ***
+# or not working with examples
+DPR_TINY = "hf-internal-testing/tiny-random-dpr"
+# - "dpr" examples/research_projects/rag-end2end-retriever/
+RAG_TINY = "hf-internal-testing/tiny-random-rag"
+# - "rag" research_projects
+LUKE_TINY = ""
+# - "luke" Entities classes - no plan to make such example
+LXMERT_TINY = "hf-internal-testing/tiny-random-lxmert"
+# - "lxmert" doesn't work with run_qa.py
+CLIP_TINY = "hf-internal-testing/tiny-random-clip"
+# - "clip" nothing under pytorch examples - XXX: Suraj is working on adding some - check by end of Sep
+SPEECH_TO_TEXT_TINY = "hf-internal-testing/tiny-random-speech_to_text"
+# - "speech_to_text", nothing under pytorch examples
+
+
+# *** Reactive mode ***
+# models with low usage, unstable API, things about to change - do nothing about the following until someone runs into a problem
+TAPAS_TINY = "hf-internal-testing/tiny-random-tapas"
+# additional notes on tapas
+# 1. requires torch_scatter - skip if it's not installed?
+# 2. "Table must be of type pd.DataFrame" failure
+
+
+# TODO: new models to add:
+#
def get_launcher(distributed=False):
@@ -112,35 +162,69 @@ def make_task_cmds():
--overwrite_output_dir
""".split()
- # XXX: try to cover as many models as possible once (it's enough to run on one task per model)
+ # try to cover as many models as possible once (it's enough to run on one task per model)
# but need a tiny model for each
#
- # should have T5_TINY, etc. global var defined
+ # should have "{model_type.upper()}_TINY" corresponding vars defined, e.g., T5_TINY, etc.
tasks2models = dict(
trans=[
"bart",
"fsmt",
+ "m2m_100",
"marian",
"mbart",
"t5",
+ "t5_v1",
+ # "mt5", missing model files
],
sum=[
"pegasus",
],
clm=[
+ "big_bird",
+ "bigbird_pegasus",
+ "blenderbot",
+ "bloom",
"gpt2",
+ "gpt_neo",
+ "gptj",
"xlm-roberta",
+ "prophetnet",
+ # "camembert", missing model files
],
mlm=[
- "electra",
+ "albert",
+ "deberta",
+ "deberta-v2",
"distilbert",
+ "electra",
+ "flaubert",
+ "funnel",
+ "layoutlm",
+ # "reformer", # multiple issues with either mlm/qa/clas
],
qa=[
+ "led",
+ "longformer",
+ "mobilebert",
+ "mpnet",
"roberta",
+ "squeezebert",
+ # "convbert", # missing tokenizer files
+ # "layoutlmv2", missing model files
],
clas=[
"bert",
"xlnet",
+ # "hubert", # missing tokenizer files
+ # "ibert", # multiple issues with either mlm/qa/clas
+ # "transfo-xl", # tokenizer issues as ctrl
+ # "ctrl", # tokenizer issues
+ # "openai-gpt", missing model files
+ # "tapas", multiple issues
+ ],
+ img_clas=[
+ "vit",
],
)
@@ -179,6 +263,13 @@ def make_task_cmds():
--max_seq_length 12
--task_name MRPC
""",
+ img_clas=f"""
+ {scripts_dir}/image-classification/run_image_classification.py
+ --dataset_name hf-internal-testing/cats_vs_dogs_sample
+ --remove_unused_columns False
+ --max_steps 10
+ --feature_extractor_name {DS_TESTS_DIRECTORY}/vit_feature_extractor.json
+ """,
)
launcher = get_launcher(distributed=True)
diff --git a/tests/deepspeed/vit_feature_extractor.json b/tests/deepspeed/vit_feature_extractor.json
new file mode 100644
index 00000000000000..bfe5a331249fa8
--- /dev/null
+++ b/tests/deepspeed/vit_feature_extractor.json
@@ -0,0 +1,4 @@
+{
+ "feature_extractor_type": "ViTFeatureExtractor",
+ "size": 30
+}
diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py
index 3d88ebda455984..64c244ae8ed2ee 100644
--- a/tests/extended/test_trainer_ext.py
+++ b/tests/extended/test_trainer_ext.py
@@ -105,6 +105,7 @@ def test_run_seq2seq_ddp(self):
self.run_seq2seq_quick(distributed=True)
# test --sharded_ddp w/o --fp16
+ @unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp(self):
@@ -118,6 +119,7 @@ def test_run_seq2seq_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
# test --sharded_ddp zero_dp_2 w/o --fp16
+ @unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp(self):
@@ -278,7 +280,8 @@ def train_and_return_metrics(optim: str) -> Tuple[int, float]:
self.assertGreater(
gpu_total_mem_diff_bytes,
bnb_saved_bytes * 0.8, # add a safety margin, if it saved slightly less
- f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
+ f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were"
+ f" {gpu_total_mem_diff_bytes}",
)
def run_trainer(
diff --git a/tests/generation/test_generation_beam_search.py b/tests/generation/test_generation_beam_search.py
index 3971dcc79c35a7..885cefa62cbd51 100644
--- a/tests/generation/test_generation_beam_search.py
+++ b/tests/generation/test_generation_beam_search.py
@@ -126,7 +126,11 @@ def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_sc
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
- beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
+ beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
+ beam_indices = tuple(tuple(b) for b in beam_indices)
+ beam_scorer.process(
+ input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
+ )
# beam scorer should be done
self.parent.assertTrue(beam_scorer.is_done)
@@ -136,7 +140,7 @@ def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_sc
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = beam_scorer.process(
- input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
+ input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
@@ -161,10 +165,15 @@ def cut_expected_tensor(tensor):
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
+ expected_beam_indices = list(range(10))
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
- input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
+ input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
+ )
+ self.parent.assertListEqual(
+ expected_beam_indices + [next_indices[batch_idx, 1].item()],
+ torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
)
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
@@ -188,6 +197,8 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
+ beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
+ beam_indices = tuple(tuple(b) for b in beam_indices)
sequence_output = beam_scorer.finalize(
input_ids,
output_scores,
@@ -196,6 +207,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
+ beam_indices=beam_indices,
)
sequences = sequence_output["sequences"]
@@ -225,6 +237,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
+ beam_indices=beam_indices,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
@@ -394,7 +407,7 @@ def cut_expected_tensor(tensor):
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
- input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
+ input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
def check_constrained_beam_scorer_finalize(
@@ -464,7 +477,7 @@ def check_constrained_beam_scorer_finalize(
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# test that the constraint is indeed fulfilled
- for (output, constraint) in [(s, c) for s in sequences for c in constraints]:
+ for output, constraint in [(s, c) for s in sequences for c in constraints]:
forced_token_ids = constraint.token_ids
if isinstance(forced_token_ids[0], list):
# disjunctive case
diff --git a/tests/generation/test_generation_tf_logits_process.py b/tests/generation/test_generation_tf_logits_process.py
index 9fb8e83fccd72a..be60335ef2f845 100644
--- a/tests/generation/test_generation_tf_logits_process.py
+++ b/tests/generation/test_generation_tf_logits_process.py
@@ -75,6 +75,7 @@ def test_min_length_dist_processor(self, use_xla):
@parameterized.expand([(False,), (True,)])
def test_temperature_dist_warper(self, use_xla):
input_ids = None
+ cur_len = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
@@ -94,8 +95,8 @@ def test_temperature_dist_warper(self, use_xla):
temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)
- warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
- warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
+ warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores), cur_len), axis=-1)
+ warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores), cur_len), axis=-1)
# uniform distribution stays uniform
tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
@@ -142,6 +143,7 @@ def test_repetition_penalty_dist_process(self, use_xla):
@parameterized.expand([(False,), (True,)])
def test_top_k_dist_warper(self, use_xla):
input_ids = None
+ cur_len = None
vocab_size = 10
batch_size = 2
@@ -153,7 +155,7 @@ def test_top_k_dist_warper(self, use_xla):
if use_xla:
top_k_warp = tf.function(top_k_warp, jit_compile=True)
- scores = top_k_warp(input_ids, ramp_logits)
+ scores = top_k_warp(input_ids, ramp_logits, cur_len)
# check that correct tokens are filtered
self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
@@ -167,12 +169,12 @@ def test_top_k_dist_warper(self, use_xla):
if use_xla:
top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
- scores = top_k_warp_safety_check(input_ids, logits)
+ scores = top_k_warp_safety_check(input_ids, logits, cur_len)
# uniform dist is not changed
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
- scores = top_k_warp_safety_check(input_ids, ramp_logits)
+ scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
@@ -180,6 +182,7 @@ def test_top_k_dist_warper(self, use_xla):
@parameterized.expand([(False,), (True,)])
def test_top_p_dist_warper(self, use_xla):
input_ids = None
+ cur_len = None
vocab_size = 10
batch_size = 2
@@ -189,7 +192,7 @@ def test_top_p_dist_warper(self, use_xla):
top_p_warp = TFTopPLogitsWarper(0.7)
if use_xla:
top_p_warp = tf.function(top_p_warp, jit_compile=True)
- filtered_dist = tf.exp(top_p_warp(input_ids, dist))
+ filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
@@ -208,7 +211,7 @@ def test_top_p_dist_warper(self, use_xla):
top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
if use_xla:
top_p_warp = tf.function(top_p_warp, jit_compile=True)
- filtered_dist = top_p_warp(input_ids, ramp_logits)
+ filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
# 2.
@@ -242,7 +245,8 @@ def test_no_repeat_ngram_dist_processor(self):
tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
)
- def test_no_bad_words_dist_processor(self):
+ @parameterized.expand([(False,), (True,)])
+ def test_no_bad_words_dist_processor(self, use_xla):
vocab_size = 5
batch_size = 2
eos_token_id = 4
@@ -255,6 +259,8 @@ def test_no_bad_words_dist_processor(self):
scores = self._get_uniform_logits(batch_size, vocab_size)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
+ if use_xla:
+ no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
@@ -322,7 +328,9 @@ def test_forced_eos_token_logits_processor(self, use_xla):
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
- def test_processor_list(self):
+ @parameterized.expand([(False,), (True,)])
+ def test_processor_list(self, use_xla):
+ # TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
batch_size = 4
cur_len = 10
vocab_size = 15
@@ -341,16 +349,24 @@ def test_processor_list(self):
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TFTopKLogitsWarper(3)
top_p_warp = TFTopPLogitsWarper(0.8)
- no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
+ # no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
+ if use_xla:
+ min_dist_proc = tf.function(min_dist_proc, jit_compile=True)
+ temp_dist_warp = tf.function(temp_dist_warp, jit_compile=True)
+ rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
+ top_k_warp = tf.function(top_k_warp, jit_compile=True)
+ top_p_warp = tf.function(top_p_warp, jit_compile=True)
+ # no_repeat_proc = tf.function(no_repeat_proc, jit_compile=True)
+ no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
# no processor list
scores = min_dist_proc(input_ids, scores, cur_len)
- scores = temp_dist_warp(input_ids, scores)
+ scores = temp_dist_warp(input_ids, scores, cur_len)
scores = rep_penalty_proc(input_ids, scores, cur_len)
- scores = top_k_warp(input_ids, scores)
- scores = top_p_warp(input_ids, scores)
- scores = no_repeat_proc(input_ids, scores, cur_len)
+ scores = top_k_warp(input_ids, scores, cur_len)
+ scores = top_p_warp(input_ids, scores, cur_len)
+ # scores = no_repeat_proc(input_ids, scores, cur_len)
scores = no_bad_words_dist_proc(input_ids, scores, cur_len)
# with processor list
@@ -361,11 +377,11 @@ def test_processor_list(self):
rep_penalty_proc,
top_k_warp,
top_p_warp,
- no_repeat_proc,
+ # no_repeat_proc,
no_bad_words_dist_proc,
]
)
- scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
+ scores_comp = processor(input_ids, scores_comp, cur_len)
# remove inf
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py
index 6006dbe21cdf01..c52a72450eb195 100644
--- a/tests/generation/test_generation_utils.py
+++ b/tests/generation/test_generation_utils.py
@@ -1654,8 +1654,12 @@ def test_diverse_beam_search(self):
self.assertListEqual(
generated_text,
[
- "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
- "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
+ "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the"
+ " middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle"
+ " name, as well as his father's first. It is the first baby for both of them.",
+ "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the"
+ " first child for both. The couple announced the pregnancy in January. The name Silas is the middle"
+ " name of Timberlake's maternal grandfather. It's also his own middle name.",
],
)
@@ -2318,6 +2322,94 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
+ @slow
+ def test_transition_scores_early_stopping(self):
+ # This is an aggressive test that makes sure that `beam_search's`
+ # transition scores are computed correctly for varying `num_return_sequences`,
+ # `num_beams` and `batch_size > 1`
+ # 2 x input_ids for "question: How are you? \n context: I had a long day, "
+ input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
+ torch_device
+ )
+
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)
+
+ result = model.generate(
+ input_ids,
+ max_length=10,
+ return_dict_in_generate=True,
+ output_scores=True,
+ forced_eos_token_id=model.config.eos_token_id,
+ num_beams=4,
+ do_sample=False,
+ num_return_sequences=3,
+ length_penalty=0.0,
+ )
+
+ transition_scores = model.compute_transition_beam_scores(
+ sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
+ )
+
+ sum_transition_scores = torch.sum(transition_scores, dim=1)
+
+ self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist())
+
+ def test_log_scores_sample_decoder_only(self):
+ articles = ["I need input_ids to generate", "Short and"]
+ tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ tokenizer.padding_side = "left"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+
+ inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
+
+ result = model.generate(
+ **inputs,
+ max_length=15,
+ return_dict_in_generate=True,
+ do_sample=False,
+ output_scores=True,
+ )
+
+ # decoder-only starts generating from `input_ids`
+ begin_generation = inputs.input_ids.shape[-1]
+
+ gen_sequences = result.sequences[:, begin_generation:]
+ probs = torch.stack(result.scores, dim=1).softmax(-1)
+
+ gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+ expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
+
+ self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
+
+ def test_log_scores_sample_encoder_decoder(self):
+ articles = ["I need input_ids to generate", "Short and"]
+ tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
+
+ inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
+
+ result = model.generate(
+ **inputs,
+ max_length=3,
+ return_dict_in_generate=True,
+ do_sample=False,
+ num_beams=1,
+ output_scores=True,
+ )
+
+ # encoder-decoder has one decoder_start_token_id by default
+ begin_generation = 1
+
+ gen_sequences = result.sequences[:, begin_generation:]
+ probs = torch.stack(result.scores, dim=1).softmax(-1)
+
+ gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+ expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
+
+ self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
+
@slow
def test_beam_search_example_integration(self):
# exactly the example provided in the docstrings of beam search, which previously
@@ -2362,8 +2454,8 @@ def test_beam_search_example_integration(self):
@slow
def test_constrained_beam_search(self):
- model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
@@ -2392,14 +2484,15 @@ def test_constrained_beam_search(self):
self.assertListEqual(
generated_text,
[
- "The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them off. They had no idea what to do",
+ "The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them"
+ " off. They had no idea what to do",
],
)
@slow
def test_constrained_beam_search_mixed(self):
- model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
flexible_phrases = tokenizer(
@@ -2437,8 +2530,8 @@ def test_constrained_beam_search_mixed(self):
@slow
def test_constrained_beam_search_mixed_mixin(self):
- model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]
@@ -2540,8 +2633,8 @@ def test_constrained_beam_search_example_integration(self):
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
def test_constrained_beam_search_mixin_type_checks(self):
- tokenizer = AutoTokenizer.from_pretrained("t5-base")
- model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+ tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
+ model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
encoder_input_str = "translate English to German: How old are you?"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
diff --git a/tests/auto/__init__.py b/tests/models/__init__.py
similarity index 100%
rename from tests/auto/__init__.py
rename to tests/models/__init__.py
diff --git a/tests/bart/__init__.py b/tests/models/albert/__init__.py
similarity index 100%
rename from tests/bart/__init__.py
rename to tests/models/albert/__init__.py
diff --git a/tests/albert/test_modeling_albert.py b/tests/models/albert/test_modeling_albert.py
similarity index 98%
rename from tests/albert/test_modeling_albert.py
rename to tests/models/albert/test_modeling_albert.py
index 125ba314ddc39e..77496699d427b0 100644
--- a/tests/albert/test_modeling_albert.py
+++ b/tests/models/albert/test_modeling_albert.py
@@ -20,8 +20,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/albert/test_modeling_flax_albert.py b/tests/models/albert/test_modeling_flax_albert.py
similarity index 98%
rename from tests/albert/test_modeling_flax_albert.py
rename to tests/models/albert/test_modeling_flax_albert.py
index 11e971684e4318..802952e52cbd85 100644
--- a/tests/albert/test_modeling_flax_albert.py
+++ b/tests/models/albert/test_modeling_flax_albert.py
@@ -19,7 +19,7 @@
from transformers import AlbertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
diff --git a/tests/albert/test_modeling_tf_albert.py b/tests/models/albert/test_modeling_tf_albert.py
similarity index 98%
rename from tests/albert/test_modeling_tf_albert.py
rename to tests/models/albert/test_modeling_tf_albert.py
index 7eacc1f32a472f..ad10228a518225 100644
--- a/tests/albert/test_modeling_tf_albert.py
+++ b/tests/models/albert/test_modeling_tf_albert.py
@@ -20,8 +20,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/albert/test_tokenization_albert.py b/tests/models/albert/test_tokenization_albert.py
similarity index 96%
rename from tests/albert/test_tokenization_albert.py
rename to tests/models/albert/test_tokenization_albert.py
index 2421da49274c50..5459917775d992 100644
--- a/tests/albert/test_tokenization_albert.py
+++ b/tests/models/albert/test_tokenization_albert.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from transformers import AlbertTokenizer, AlbertTokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/spiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
@require_sentencepiece
diff --git a/tests/barthez/__init__.py b/tests/models/auto/__init__.py
similarity index 100%
rename from tests/barthez/__init__.py
rename to tests/models/auto/__init__.py
diff --git a/tests/auto/test_configuration_auto.py b/tests/models/auto/test_configuration_auto.py
similarity index 85%
rename from tests/auto/test_configuration_auto.py
rename to tests/models/auto/test_configuration_auto.py
index f07bb428347a7c..2695082c412d07 100644
--- a/tests/auto/test_configuration_auto.py
+++ b/tests/models/auto/test_configuration_auto.py
@@ -14,6 +14,7 @@
# limitations under the License.
import importlib
+import json
import os
import sys
import tempfile
@@ -24,15 +25,15 @@
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.roberta.configuration_roberta import RobertaConfig
-from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
+from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, get_tests_dir
-sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
+sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
-SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/dummy-config.json")
+SAMPLE_ROBERTA_CONFIG = get_tests_dir("fixtures/dummy-config.json")
class AutoConfigTest(unittest.TestCase):
@@ -57,14 +58,14 @@ def test_config_for_model_str(self):
self.assertIsInstance(config, RobertaConfig)
def test_pattern_matching_fallback(self):
- """
- In cases where config.json doesn't include a model_type,
- perform a few safety checks on the config mapping's order.
- """
- # no key string should be included in a later key string (typical failure case)
- keys = list(CONFIG_MAPPING.keys())
- for i, key in enumerate(keys):
- self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # This model name contains bert and roberta, but roberta ends up being picked.
+ folder = os.path.join(tmp_dir, "fake-roberta")
+ os.makedirs(folder, exist_ok=True)
+ with open(os.path.join(folder, "config.json"), "w") as f:
+ f.write(json.dumps({}))
+ config = AutoConfig.from_pretrained(folder)
+ self.assertEqual(type(config), RobertaConfig)
def test_new_config_registration(self):
try:
diff --git a/tests/auto/test_feature_extraction_auto.py b/tests/models/auto/test_feature_extraction_auto.py
similarity index 90%
rename from tests/auto/test_feature_extraction_auto.py
rename to tests/models/auto/test_feature_extraction_auto.py
index b0c11c517ab75b..e9d044e8daac07 100644
--- a/tests/auto/test_feature_extraction_auto.py
+++ b/tests/models/auto/test_feature_extraction_auto.py
@@ -14,7 +14,6 @@
# limitations under the License.
import json
-import os
import sys
import tempfile
import unittest
@@ -28,20 +27,18 @@
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
)
-from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
+from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, get_tests_dir
-sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
+sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
-SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures")
-SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
- os.path.dirname(os.path.abspath(__file__)), "../fixtures/dummy_feature_extractor_config.json"
-)
-SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/dummy-config.json")
+SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
+SAMPLE_FEATURE_EXTRACTION_CONFIG = get_tests_dir("fixtures/dummy_feature_extractor_config.json")
+SAMPLE_CONFIG = get_tests_dir("fixtures/dummy-config.json")
class AutoFeatureExtractorTest(unittest.TestCase):
diff --git a/tests/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py
similarity index 99%
rename from tests/auto/test_modeling_auto.py
rename to tests/models/auto/test_modeling_auto.py
index 02ecb08e1e2f68..3731d70f5bb5af 100644
--- a/tests/auto/test_modeling_auto.py
+++ b/tests/models/auto/test_modeling_auto.py
@@ -32,7 +32,7 @@
from ..bert.test_modeling_bert import BertModelTester
-sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
+sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
diff --git a/tests/auto/test_modeling_flax_auto.py b/tests/models/auto/test_modeling_flax_auto.py
similarity index 100%
rename from tests/auto/test_modeling_flax_auto.py
rename to tests/models/auto/test_modeling_flax_auto.py
diff --git a/tests/auto/test_modeling_tf_auto.py b/tests/models/auto/test_modeling_tf_auto.py
similarity index 100%
rename from tests/auto/test_modeling_tf_auto.py
rename to tests/models/auto/test_modeling_tf_auto.py
diff --git a/tests/auto/test_modeling_tf_pytorch.py b/tests/models/auto/test_modeling_tf_pytorch.py
similarity index 100%
rename from tests/auto/test_modeling_tf_pytorch.py
rename to tests/models/auto/test_modeling_tf_pytorch.py
diff --git a/tests/auto/test_processor_auto.py b/tests/models/auto/test_processor_auto.py
similarity index 96%
rename from tests/auto/test_processor_auto.py
rename to tests/models/auto/test_processor_auto.py
index 02fa8696c4a156..26122e6164ab1a 100644
--- a/tests/auto/test_processor_auto.py
+++ b/tests/models/auto/test_processor_auto.py
@@ -36,12 +36,12 @@
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
)
-from transformers.testing_utils import PASS, USER, is_staging_test
+from transformers.testing_utils import PASS, USER, get_tests_dir, is_staging_test
from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
from transformers.utils import FEATURE_EXTRACTOR_NAME, is_tokenizers_available
-sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
+sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
@@ -49,12 +49,9 @@
from test_module.custom_tokenization import CustomTokenizer # noqa E402
-SAMPLE_PROCESSOR_CONFIG = os.path.join(
- os.path.dirname(os.path.abspath(__file__)), "../fixtures/dummy_feature_extractor_config.json"
-)
-SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/vocab.json")
-
-SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures")
+SAMPLE_PROCESSOR_CONFIG = get_tests_dir("fixtures/dummy_feature_extractor_config.json")
+SAMPLE_VOCAB = get_tests_dir("fixtures/vocab.json")
+SAMPLE_PROCESSOR_CONFIG_DIR = get_tests_dir("fixtures")
class AutoFeatureExtractorTest(unittest.TestCase):
diff --git a/tests/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py
similarity index 99%
rename from tests/auto/test_tokenization_auto.py
rename to tests/models/auto/test_tokenization_auto.py
index 57041e5830296c..1e1abb9245842c 100644
--- a/tests/auto/test_tokenization_auto.py
+++ b/tests/models/auto/test_tokenization_auto.py
@@ -53,7 +53,7 @@
)
-sys.path.append(str(Path(__file__).parent.parent / "utils"))
+sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
from test_module.custom_tokenization import CustomTokenizer # noqa E402
diff --git a/tests/bartpho/__init__.py b/tests/models/bart/__init__.py
similarity index 100%
rename from tests/bartpho/__init__.py
rename to tests/models/bart/__init__.py
diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
new file mode 100644
index 00000000000000..b36bda3b71ece3
--- /dev/null
+++ b/tests/models/bart/test_modeling_bart.py
@@ -0,0 +1,1414 @@
+# coding=utf-8
+# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch BART model. """
+
+
+import copy
+import tempfile
+import unittest
+
+import timeout_decorator # noqa
+
+from transformers import BartConfig, is_torch_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
+from transformers.utils import cached_property
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ AutoModelForSequenceClassification,
+ BartForCausalLM,
+ BartForConditionalGeneration,
+ BartForQuestionAnswering,
+ BartForSequenceClassification,
+ BartModel,
+ BartTokenizer,
+ pipeline,
+ )
+ from transformers.models.bart.modeling_bart import BartDecoder, BartEncoder, shift_tokens_right
+
+
+def prepare_bart_inputs_dict(
+ config,
+ input_ids,
+ decoder_input_ids=None,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+):
+ if attention_mask is None:
+ attention_mask = input_ids.ne(config.pad_token_id)
+ if decoder_attention_mask is None:
+ decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+ if head_mask is None:
+ head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
+ if decoder_head_mask is None:
+ decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+ if cross_attn_head_mask is None:
+ cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+ return {
+ "input_ids": input_ids,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ }
+
+
+class BartModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
+ 3,
+ )
+ input_ids[:, -1] = self.eos_token_id # Eos Token
+
+ decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ config = self.get_config()
+ inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
+ return config, inputs_dict
+
+ def get_config(self):
+ return BartConfig(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ encoder_layers=self.num_hidden_layers,
+ decoder_layers=self.num_hidden_layers,
+ encoder_attention_heads=self.num_attention_heads,
+ decoder_attention_heads=self.num_attention_heads,
+ encoder_ffn_dim=self.intermediate_size,
+ decoder_ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ )
+
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.max_position_embeddings = 100
+ config.vocab_size = 300
+ return config
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
+ model = BartModel(config=config).get_decoder().to(torch_device).eval()
+ input_ids = inputs_dict["input_ids"]
+ attention_mask = inputs_dict["attention_mask"]
+ head_mask = inputs_dict["head_mask"]
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical multiple next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_attn_mask = ids_tensor((self.batch_size, 3), 2)
+
+ # append to next input_ids and
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def check_encoder_decoder_model_standalone(self, config, inputs_dict):
+ model = BartModel(config=config).to(torch_device).eval()
+ outputs = model(**inputs_dict)
+
+ encoder_last_hidden_state = outputs.encoder_last_hidden_state
+ last_hidden_state = outputs.last_hidden_state
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ encoder = model.get_encoder()
+ encoder.save_pretrained(tmpdirname)
+ encoder = BartEncoder.from_pretrained(tmpdirname).to(torch_device)
+
+ encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
+ 0
+ ]
+
+ self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ decoder = model.get_decoder()
+ decoder.save_pretrained(tmpdirname)
+ decoder = BartDecoder.from_pretrained(tmpdirname).to(torch_device)
+
+ last_hidden_state_2 = decoder(
+ input_ids=inputs_dict["decoder_input_ids"],
+ attention_mask=inputs_dict["decoder_attention_mask"],
+ encoder_hidden_states=encoder_last_hidden_state,
+ encoder_attention_mask=inputs_dict["attention_mask"],
+ )[0]
+
+ self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
+
+
+@require_torch
+class BartHeadTests(unittest.TestCase):
+ vocab_size = 99
+
+ def _get_config_and_data(self):
+ input_ids = torch.tensor(
+ [
+ [71, 82, 18, 33, 46, 91, 2],
+ [68, 34, 26, 58, 30, 82, 2],
+ [5, 97, 17, 39, 94, 40, 2],
+ [76, 83, 94, 25, 70, 78, 2],
+ [87, 59, 41, 35, 48, 66, 2],
+ [55, 13, 16, 58, 5, 2, 1], # note padding
+ [64, 27, 31, 51, 12, 75, 2],
+ [52, 64, 86, 17, 83, 39, 2],
+ [48, 61, 9, 24, 71, 82, 2],
+ [26, 1, 60, 48, 22, 13, 2],
+ [21, 5, 62, 28, 14, 76, 2],
+ [45, 98, 37, 86, 59, 48, 2],
+ [70, 70, 50, 9, 28, 0, 2],
+ ],
+ dtype=torch.long,
+ device=torch_device,
+ )
+
+ batch_size = input_ids.shape[0]
+ config = BartConfig(
+ vocab_size=self.vocab_size,
+ d_model=24,
+ encoder_layers=2,
+ decoder_layers=2,
+ encoder_attention_heads=2,
+ decoder_attention_heads=2,
+ encoder_ffn_dim=32,
+ decoder_ffn_dim=32,
+ max_position_embeddings=48,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ )
+ return config, input_ids, batch_size
+
+ def test_sequence_classification_forward(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ labels = _long_tensor([2] * batch_size).to(torch_device)
+ model = BartForSequenceClassification(config)
+ model.to(torch_device)
+ outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
+ expected_shape = torch.Size((batch_size, config.num_labels))
+ self.assertEqual(outputs["logits"].shape, expected_shape)
+ self.assertIsInstance(outputs["loss"].item(), float)
+
+ def test_question_answering_forward(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
+ model = BartForQuestionAnswering(config)
+ model.to(torch_device)
+ outputs = model(
+ input_ids=input_ids,
+ start_positions=sequence_labels,
+ end_positions=sequence_labels,
+ )
+
+ self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
+ self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
+ self.assertIsInstance(outputs["loss"].item(), float)
+
+ @timeout_decorator.timeout(1)
+ def test_lm_forward(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
+ lm_model = BartForConditionalGeneration(config)
+ lm_model.to(torch_device)
+ outputs = lm_model(input_ids=input_ids, labels=lm_labels)
+ expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
+ self.assertEqual(outputs["logits"].shape, expected_shape)
+ self.assertIsInstance(outputs["loss"].item(), float)
+
+ def test_lm_uneven_forward(self):
+ config = BartConfig(
+ vocab_size=self.vocab_size,
+ d_model=14,
+ encoder_layers=2,
+ decoder_layers=2,
+ encoder_attention_heads=2,
+ decoder_attention_heads=2,
+ encoder_ffn_dim=8,
+ decoder_ffn_dim=8,
+ max_position_embeddings=48,
+ )
+ lm_model = BartForConditionalGeneration(config).to(torch_device)
+ context = torch.tensor(
+ [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], device=torch_device, dtype=torch.long
+ )
+ summary = torch.tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], device=torch_device, dtype=torch.long)
+ outputs = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
+ expected_shape = (*summary.shape, config.vocab_size)
+ self.assertEqual(outputs["logits"].shape, expected_shape)
+
+ def test_generate_beam_search(self):
+ input_ids = torch.tensor([[71, 82, 2], [68, 34, 2]], device=torch_device, dtype=torch.long)
+ config = BartConfig(
+ vocab_size=self.vocab_size,
+ d_model=24,
+ encoder_layers=2,
+ decoder_layers=2,
+ encoder_attention_heads=2,
+ decoder_attention_heads=2,
+ encoder_ffn_dim=32,
+ decoder_ffn_dim=32,
+ max_position_embeddings=48,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ )
+ lm_model = BartForConditionalGeneration(config).to(torch_device)
+ lm_model.eval()
+
+ max_length = 5
+ generated_ids = lm_model.generate(
+ input_ids.clone(),
+ do_sample=True,
+ num_return_sequences=1,
+ num_beams=2,
+ no_repeat_ngram_size=3,
+ max_length=max_length,
+ )
+ self.assertEqual(generated_ids.shape, (input_ids.shape[0], max_length))
+
+ def test_shift_tokens_right(self):
+ input_ids = torch.tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=torch.long)
+ shifted = shift_tokens_right(input_ids, 1, 2)
+ n_pad_before = input_ids.eq(1).float().sum()
+ n_pad_after = shifted.eq(1).float().sum()
+ self.assertEqual(shifted.shape, input_ids.shape)
+ self.assertEqual(n_pad_after, n_pad_before - 1)
+ self.assertTrue(torch.eq(shifted[:, 0], 2).all())
+
+ @slow
+ def test_tokenization(self):
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
+ examples = [" Hello world", " DomDramg"] # need leading spaces for equality
+ fairseq_results = [
+ torch.tensor([0, 20920, 232, 2]),
+ torch.tensor([0, 11349, 495, 4040, 571, 2]),
+ ]
+ for ex, desired_result in zip(examples, fairseq_results):
+ bart_toks = tokenizer.encode(ex, return_tensors="pt").squeeze()
+ assert_tensors_close(desired_result.long(), bart_toks, prefix=ex)
+
+ def test_generate_fp16(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ attention_mask = input_ids.ne(1).to(torch_device)
+ model = BartForConditionalGeneration(config).eval().to(torch_device)
+ if torch_device == "cuda":
+ model.half()
+ model.generate(input_ids, attention_mask=attention_mask)
+ model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+ def test_dummy_inputs(self):
+ config, *_ = self._get_config_and_data()
+ model = BartForConditionalGeneration(config).eval().to(torch_device)
+ model(**model.dummy_inputs)
+
+ def test_resize_tokens_embeddings_more(self):
+ config, input_ids, _ = self._get_config_and_data()
+
+ def _get_embs(m):
+ return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())
+
+ model = BartForConditionalGeneration(config).eval().to(torch_device)
+ input, output = _get_embs(model)
+ self.assertTrue(torch.eq(input, output).all())
+ new_vocab_size = 45
+ model.resize_token_embeddings(new_vocab_size)
+ input_new, output_new = _get_embs(model)
+ self.assertEqual(input_new.shape, (new_vocab_size, config.d_model))
+ self.assertEqual(output_new.shape, (new_vocab_size, config.d_model))
+ self.assertTrue(torch.eq(input_new, output_new).all())
+
+
+@require_torch
+class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
+ if is_torch_available()
+ else ()
+ )
+ all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
+ is_encoder_decoder = True
+ fx_compatible = True
+ test_pruning = False
+ test_missing_keys = False
+
+ def setUp(self):
+ self.model_tester = BartModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=BartConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_save_load_strict(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
+ self.assertEqual(info["missing_keys"], [])
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_encoder_decoder_model_standalone(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
+
+ # BartForSequenceClassification does not support inputs_embeds
+ def test_inputs_embeds(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in (BartModel, BartForConditionalGeneration, BartForQuestionAnswering):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+ if not self.is_encoder_decoder:
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+ else:
+ encoder_input_ids = inputs["input_ids"]
+ decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+ del inputs["input_ids"]
+ inputs.pop("decoder_input_ids", None)
+
+ wte = model.get_input_embeddings()
+ if not self.is_encoder_decoder:
+ inputs["inputs_embeds"] = wte(input_ids)
+ else:
+ inputs["inputs_embeds"] = wte(encoder_input_ids)
+ inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+ with torch.no_grad():
+ model(**inputs)[0]
+
+ def test_generate_fp16(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs()
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ model = BartForConditionalGeneration(config).eval().to(torch_device)
+ if torch_device == "cuda":
+ model.half()
+ model.generate(input_ids, attention_mask=attention_mask)
+ model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+
+def assert_tensors_close(a, b, atol=1e-12, prefix=""):
+ """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
+ if a is None and b is None:
+ return True
+ try:
+ if torch.allclose(a, b, atol=atol):
+ return True
+ raise
+ except Exception:
+ pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
+ if a.numel() > 100:
+ msg = f"tensor values are {pct_different:.1%} percent different."
+ else:
+ msg = f"{a} != {b}"
+ if prefix:
+ msg = prefix + ": " + msg
+ raise AssertionError(msg)
+
+
+def _long_tensor(tok_lst):
+ return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
+
+
+@require_torch
+@slow
+class FastIntegrationTests(unittest.TestCase):
+ """These tests are useful for debugging since they operate on a model with 1 encoder layer and 1 decoder layer."""
+
+ @cached_property
+ def tok(self):
+ return BartTokenizer.from_pretrained("facebook/bart-large")
+
+ @cached_property
+ def xsum_1_1_model(self):
+ return BartForConditionalGeneration.from_pretrained("sshleifer/distilbart-xsum-1-1")
+
+ def test_xsum_1_1_generation(self):
+ hf = self.xsum_1_1_model
+ tok = self.tok
+ ARTICLE = (
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes."
+ )
+ EXPECTED = (
+ " The International Criminal Court (ICC) has announced that it has been announced by the International"
+ " Criminal court."
+ )
+
+ dct = tok(ARTICLE, return_tensors="pt")
+ generated_ids = hf.generate(**dct, num_beams=4)
+ result = tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ assert EXPECTED == result
+
+ def test_xsum_1_1_batch_generation(self):
+ # test batch
+
+ batch = self.tok(
+ [
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories."
+ " The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is"
+ " based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted"
+ ' its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including'
+ ' East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination'
+ " into the situation in Palestinian territories, paving the way for possible war crimes investigations"
+ " against Israelis. As members of the court, Palestinians may be subject to counter-charges as well."
+ " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts"
+ " to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony,"
+ ' said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome'
+ ' Statute today, the world is also a step closer to ending a long era of impunity and injustice," he'
+ ' said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of'
+ ' justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was'
+ ' just the first step for the Palestinians. "As the Rome Statute today enters into force for the State'
+ " of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a"
+ ' State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she'
+ ' said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize'
+ " Palestine for joining the ICC should immediately end their pressure, and countries that support"
+ " universal acceptance of the court's treaty should speak out to welcome its membership,\" said"
+ " Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts"
+ " to undermine international justice, not Palestine's decision to join a treaty to which over 100"
+ ' countries around the world are members." In January, when the preliminary ICC examination was'
+ " opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was"
+ ' overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s'
+ ' decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we'
+ ' do not believe that it is eligible to join the ICC," the State Department said in a statement. It'
+ ' urged the warring sides to resolve their differences through direct negotiations. "We will continue'
+ ' to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said.'
+ " But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows'
+ " the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor"
+ ' Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality."'
+ " The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The"
+ " inquiry will include alleged war crimes committed since June. The International Criminal Court was"
+ " set up in 2002 to prosecute genocide, crimes against humanity and war crimes.",
+ "The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted"
+ " Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor"
+ ' Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A'
+ " person who has such a video needs to immediately give it to the investigators.\" Robin's comments"
+ " follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the"
+ " French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was"
+ " recovered from a phone at the wreckage site. The two publications described the supposed video, but"
+ " did not post it on their websites. The publications said that they watched the video, which was"
+ " found by a source close to the investigation. \"One can hear cries of 'My God' in several"
+ ' languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps'
+ " of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy"
+ ' shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing'
+ " scene,\" said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident"
+ " investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc"
+ " Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the"
+ ' Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell'
+ ' phones have been collected at the site, he said, but that they "hadn\'t been exploited yet."'
+ " Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute"
+ " in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working"
+ " hand-in-hand with investigators. But none of the cell phones found so far have been sent to the"
+ " institute, Menichini said. Asked whether staff involved in the search could have leaked a memory"
+ ' card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett:'
+ ' Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are'
+ ' "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is'
+ " something we did not know before. ... Overall we can say many things of the investigation weren't"
+ ' revealed by the investigation at the beginning," he said. What was mental state of Germanwings'
+ " co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled"
+ " depression years before he took the controls of Germanwings Flight 9525, which he's accused of"
+ " deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school"
+ ' in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email'
+ " correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa"
+ " said, included medical documents he submitted in connection with resuming his flight training. The"
+ " announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle"
+ " with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa,"
+ " whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday"
+ ' as a "swift and seamless clarification" and said it was sharing the information and documents --'
+ " including training and medical records -- with public prosecutors. Spohr traveled to the crash site"
+ " Wednesday, where recovery teams have been working for the past week to recover human remains and"
+ " plane debris scattered across a steep mountainside. He saw the crisis center set up in"
+ " Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving"
+ " families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no"
+ " visible human remains were left at the site but recovery teams would keep searching. French"
+ " President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the"
+ " victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini"
+ " said. Among those personal belongings could be more cell phones belonging to the 144 passengers and"
+ " six crew on board. Check out the latest from our correspondents . The details about Lubitz's"
+ " correspondence with the flight school during his training were among several developments as"
+ " investigators continued to delve into what caused the crash and Lubitz's possible motive for"
+ " downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical"
+ ' certificate, had passed all his examinations and "held all the licenses required." Earlier, a'
+ " spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal"
+ " Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent"
+ " psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting"
+ " Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether"
+ " Lubitz feared his medical condition would cause him to lose his pilot's license, a European"
+ ' government official briefed on the investigation told CNN on Tuesday. While flying was "a big part'
+ " of his life,\" the source said, it's only one theory being considered. Another source, a law"
+ " enforcement official briefed on the investigation, also told CNN that authorities believe the"
+ " primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly"
+ " because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor"
+ " and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had"
+ " psychological issues, the European government official said. But no matter what details emerge about"
+ " his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the'
+ " fact that maybe they weren't going to keep doing their job and they're upset about that and so"
+ ' they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels'
+ " entitled to also take that rage and turn it outward on 149 other people who had nothing to do with"
+ " the person's problems.\" Germanwings crash compensation: What we know . Who was the captain of"
+ " Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from"
+ " Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff,"
+ " Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.",
+ ],
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ )
+ generated_ids = self.xsum_1_1_model.generate(**batch, num_beams=4)
+ result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
+ assert (
+ result[0]
+ == " The International Criminal Court (ICC) has announced that it has been announced by the International"
+ " Criminal court."
+ )
+ assert (
+ result[1]
+ == " An investigation into the crash that killed at least 10 people in the French capital has been"
+ " released by the French police investigating the crash."
+ )
+
+ def test_encoder_equiv(self):
+ # test batch
+
+ batch = self.tok(
+ [
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories."
+ " The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is"
+ " based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted"
+ ' its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including'
+ ' East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination'
+ " into the situation in Palestinian territories, paving the way for possible war crimes investigations"
+ " against Israelis. As members of the court, Palestinians may be subject to counter-charges as well."
+ " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts"
+ " to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony,"
+ ' said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome'
+ ' Statute today, the world is also a step closer to ending a long era of impunity and injustice," he'
+ ' said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of'
+ ' justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was'
+ ' just the first step for the Palestinians. "As the Rome Statute today enters into force for the State'
+ " of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a"
+ ' State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she'
+ ' said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize'
+ " Palestine for joining the ICC should immediately end their pressure, and countries that support"
+ " universal acceptance of the court's treaty should speak out to welcome its membership,\" said"
+ " Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts"
+ " to undermine international justice, not Palestine's decision to join a treaty to which over 100"
+ ' countries around the world are members." In January, when the preliminary ICC examination was'
+ " opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was"
+ ' overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s'
+ ' decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we'
+ ' do not believe that it is eligible to join the ICC," the State Department said in a statement. It'
+ ' urged the warring sides to resolve their differences through direct negotiations. "We will continue'
+ ' to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said.'
+ " But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows'
+ " the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor"
+ ' Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality."'
+ " The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The"
+ " inquiry will include alleged war crimes committed since June. The International Criminal Court was"
+ " set up in 2002 to prosecute genocide, crimes against humanity and war crimes.",
+ "The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted"
+ " Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor"
+ ' Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A'
+ " person who has such a video needs to immediately give it to the investigators.\" Robin's comments"
+ " follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the"
+ " French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was"
+ " recovered from a phone at the wreckage site. The two publications described the supposed video, but"
+ " did not post it on their websites. The publications said that they watched the video, which was"
+ " found by a source close to the investigation. \"One can hear cries of 'My God' in several"
+ ' languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps'
+ " of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy"
+ ' shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing'
+ " scene,\" said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident"
+ " investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc"
+ " Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the"
+ ' Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell'
+ ' phones have been collected at the site, he said, but that they "hadn\'t been exploited yet."'
+ " Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute"
+ " in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working"
+ " hand-in-hand with investigators. But none of the cell phones found so far have been sent to the"
+ " institute, Menichini said. Asked whether staff involved in the search could have leaked a memory"
+ ' card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett:'
+ ' Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are'
+ ' "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is'
+ " something we did not know before. ... Overall we can say many things of the investigation weren't"
+ ' revealed by the investigation at the beginning," he said. What was mental state of Germanwings'
+ " co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled"
+ " depression years before he took the controls of Germanwings Flight 9525, which he's accused of"
+ " deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school"
+ ' in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email'
+ " correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa"
+ " said, included medical documents he submitted in connection with resuming his flight training. The"
+ " announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle"
+ " with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa,"
+ " whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday"
+ ' as a "swift and seamless clarification" and said it was sharing the information and documents --'
+ " including training and medical records -- with public prosecutors. Spohr traveled to the crash site"
+ " Wednesday, where recovery teams have been working for the past week to recover human remains and"
+ " plane debris scattered across a steep mountainside. He saw the crisis center set up in"
+ " Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving"
+ " families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no"
+ " visible human remains were left at the site but recovery teams would keep searching. French"
+ " President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the"
+ " victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini"
+ " said. Among those personal belongings could be more cell phones belonging to the 144 passengers and"
+ " six crew on board. Check out the latest from our correspondents . The details about Lubitz's"
+ " correspondence with the flight school during his training were among several developments as"
+ " investigators continued to delve into what caused the crash and Lubitz's possible motive for"
+ " downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical"
+ ' certificate, had passed all his examinations and "held all the licenses required." Earlier, a'
+ " spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal"
+ " Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent"
+ " psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting"
+ " Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether"
+ " Lubitz feared his medical condition would cause him to lose his pilot's license, a European"
+ ' government official briefed on the investigation told CNN on Tuesday. While flying was "a big part'
+ " of his life,\" the source said, it's only one theory being considered. Another source, a law"
+ " enforcement official briefed on the investigation, also told CNN that authorities believe the"
+ " primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly"
+ " because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor"
+ " and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had"
+ " psychological issues, the European government official said. But no matter what details emerge about"
+ " his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the'
+ " fact that maybe they weren't going to keep doing their job and they're upset about that and so"
+ ' they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels'
+ " entitled to also take that rage and turn it outward on 149 other people who had nothing to do with"
+ " the person's problems.\" Germanwings crash compensation: What we know . Who was the captain of"
+ " Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from"
+ " Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff,"
+ " Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.",
+ ],
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ )
+ features = self.xsum_1_1_model.get_encoder()(**batch).last_hidden_state
+ expected = [[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]]
+ assert_tensors_close(features[0, :3, :3], torch.tensor(expected), atol=1e-3)
+
+
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class BartModelIntegrationTests(unittest.TestCase):
+ @cached_property
+ def default_tokenizer(self):
+ return BartTokenizer.from_pretrained("facebook/bart-large")
+
+ @slow
+ def test_inference_no_head(self):
+ model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ attention_mask = input_ids.ne(model.config.pad_token_id)
+ with torch.no_grad():
+ output = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+ expected_shape = torch.Size((1, 11, 1024))
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = torch.tensor(
+ [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device
+ )
+ self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3))
+
+ @slow
+ def test_base_mask_filling(self):
+ pbase = pipeline(task="fill-mask", model="facebook/bart-base")
+ src_text = [" I went to the ."]
+ results = [x["token_str"] for x in pbase(src_text)]
+ assert " bathroom" in results
+
+ @slow
+ def test_large_mask_filling(self):
+ plarge = pipeline(task="fill-mask", model="facebook/bart-large")
+ src_text = [" I went to the ."]
+ results = [x["token_str"] for x in plarge(src_text)]
+ expected_results = [" bathroom", " gym", " wrong", " movies", " hospital"]
+ self.assertListEqual(results, expected_results)
+
+ @slow
+ def test_mnli_inference(self):
+ example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1]
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b])
+
+ model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli").to(
+ torch_device
+ ) # eval called in from_pre
+ attention_mask = input_ids.ne(model.config.pad_token_id)
+ # Test that model hasn't changed
+ with torch.no_grad():
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+
+ batched_logits = outputs.logits
+ expected_shape = torch.Size((2, 3))
+ self.assertEqual(batched_logits.shape, expected_shape)
+ expected_slice = torch.tensor([[0.1907, 1.4342, -1.0289]], device=torch_device)
+ logits_arr = batched_logits[0].detach()
+
+ # Test that padding does not change results
+ input_ids_no_pad = _long_tensor([example_b[:-1]])
+ attention_mask_no_pad = input_ids_no_pad.ne(model.config.pad_token_id)
+
+ with torch.no_grad():
+ logits2 = model(input_ids=input_ids_no_pad, attention_mask=attention_mask_no_pad).logits.squeeze()
+ assert_tensors_close(batched_logits[1], logits2, atol=1e-3)
+ assert_tensors_close(expected_slice, logits_arr, atol=1e-3)
+
+ @slow
+ def test_xsum_summarization_same_as_fairseq(self):
+ model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device)
+ tok = self.default_tokenizer
+
+ PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
+
+ EXPECTED_SUMMARY = (
+ "California's largest power company has begun shutting off electricity to thousands of customers in the"
+ " state."
+ )
+ dct = tok.batch_encode_plus(
+ [PGE_ARTICLE],
+ max_length=1024,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ ).to(torch_device)
+
+ hypotheses_batch = model.generate(
+ input_ids=dct["input_ids"],
+ attention_mask=dct["attention_mask"],
+ num_beams=2,
+ max_length=62,
+ min_length=11,
+ length_penalty=1.0,
+ no_repeat_ngram_size=3,
+ early_stopping=True,
+ decoder_start_token_id=model.config.eos_token_id,
+ )
+
+ decoded = tok.batch_decode(
+ hypotheses_batch,
+ skip_special_tokens=True,
+ )
+ self.assertEqual(EXPECTED_SUMMARY, decoded[0])
+
+ def test_xsum_config_generation_params(self):
+ config = BartConfig.from_pretrained("facebook/bart-large-xsum")
+ expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
+ config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
+ self.assertDictEqual(expected_params, config_params)
+
+ @slow
+ def test_cnn_summarization_same_as_fairseq(self):
+ hf = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
+ tok = BartTokenizer.from_pretrained("facebook/bart-large")
+
+ FRANCE_ARTICLE = ( # @noq
+ " Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
+
+ SHORTER_ARTICLE = (
+ " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+
+ # The below article tests that we don't add any hypotheses outside of the top n_beams
+ IRAN_ARTICLE = (
+ " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
+
+ ARTICLE_SUBWAY = (
+ " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
+
+ dct = tok.batch_encode_plus(
+ [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
+ max_length=1024,
+ padding="max_length",
+ truncation_strategy="only_first",
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ self.assertEqual(1024, dct["input_ids"].shape[1])
+ hypotheses_batch = hf.generate(
+ input_ids=dct["input_ids"].to(torch_device),
+ attention_mask=dct["attention_mask"].to(torch_device),
+ num_beams=2,
+ )
+ assert hypotheses_batch[:, 1].eq(0).all().item()
+
+ EXPECTED = [
+ "A French prosecutor says he is not aware of any video footage from on board the plane. Two German "
+ "magazines claim to have found a cell phone video showing the crash. The publications say they watched "
+ "the video, which was found by a source close to the investigation. All 150 on board Germanwings Flight "
+ "9525 were killed.",
+ "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court "
+ "jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the "
+ "Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a "
+ "move toward greater justice.",
+ "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The "
+ "debate that has already begun will likely result in more heat than light. He says critics have made "
+ "dubious assumptions and doubtful assertions. Bergen says the goal was to block Iran from building a "
+ "nuclear weapon.",
+ "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors "
+ "say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the "
+ "Bronx on Friday. If convicted, she faces up to four years in prison.",
+ ]
+
+ generated_summaries = tok.batch_decode(
+ hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
+ )
+ assert generated_summaries == EXPECTED
+
+
+class BartStandaloneDecoderModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ d_model=16,
+ decoder_seq_length=7,
+ is_training=True,
+ is_decoder=True,
+ use_attention_mask=True,
+ use_cache=False,
+ use_labels=True,
+ decoder_start_token_id=2,
+ decoder_ffn_dim=32,
+ decoder_layers=4,
+ encoder_attention_heads=4,
+ decoder_attention_heads=4,
+ max_position_embeddings=30,
+ is_encoder_decoder=False,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.decoder_seq_length = decoder_seq_length
+ # For common tests
+ self.seq_length = self.decoder_seq_length
+ self.is_training = is_training
+ self.use_attention_mask = use_attention_mask
+ self.use_labels = use_labels
+
+ self.vocab_size = vocab_size
+ self.d_model = d_model
+ self.hidden_size = d_model
+ self.num_hidden_layers = decoder_layers
+ self.decoder_layers = decoder_layers
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_attention_heads = decoder_attention_heads
+ self.num_attention_heads = decoder_attention_heads
+ self.eos_token_id = eos_token_id
+ self.bos_token_id = bos_token_id
+ self.pad_token_id = pad_token_id
+ self.decoder_start_token_id = decoder_start_token_id
+ self.use_cache = use_cache
+ self.max_position_embeddings = max_position_embeddings
+ self.is_encoder_decoder = is_encoder_decoder
+
+ self.scope = None
+ self.decoder_key_length = decoder_seq_length
+ self.base_model_out_len = 2
+ self.decoder_attention_idx = 1
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
+
+ lm_labels = None
+ if self.use_labels:
+ lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ config = BartConfig(
+ vocab_size=self.vocab_size,
+ d_model=self.d_model,
+ encoder_layers=self.decoder_layers,
+ decoder_layers=self.decoder_layers,
+ decoder_ffn_dim=self.decoder_ffn_dim,
+ encoder_attention_heads=self.encoder_attention_heads,
+ decoder_attention_heads=self.decoder_attention_heads,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ use_cache=self.use_cache,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ max_position_embeddings=self.max_position_embeddings,
+ is_encoder_decoder=self.is_encoder_decoder,
+ )
+
+ return (
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ (
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ ) = self.prepare_config_and_inputs()
+
+ encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size])
+ encoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
+
+ return (
+ config,
+ input_ids,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ lm_labels,
+ )
+
+ def create_and_check_decoder_model_past(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ ):
+ config.use_cache = True
+ model = BartDecoder(config=config).to(torch_device).eval()
+ # first forward pass
+ outputs = model(input_ids, use_cache=True)
+ outputs_use_cache_conf = model(input_ids)
+ outputs_no_past = model(input_ids, use_cache=False)
+
+ self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+ self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+ past_key_values = outputs["past_key_values"]
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # append to next input_ids and
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+ output_from_no_past = model(next_input_ids)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
+
+ def create_and_check_decoder_model_attention_mask_past(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ ):
+ model = BartDecoder(config=config).to(torch_device).eval()
+
+ # create attention mask
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ half_seq_length = input_ids.shape[-1] // 2
+ attn_mask[:, half_seq_length:] = 0
+
+ # first forward pass
+ past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # change a random masked slice from input_ids
+ random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+ random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+ input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
+
+ # append to next input_ids and attn_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ attn_mask = torch.cat(
+ [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
+ dim=1,
+ )
+
+ # get two different outputs
+ output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ attention_mask,
+ lm_labels,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
+ all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
+ fx_comptatible = True
+ test_pruning = False
+ is_encoder_decoder = False
+
+ def setUp(
+ self,
+ ):
+ self.model_tester = BartStandaloneDecoderModelTester(self, is_training=False)
+ self.config_tester = ConfigTester(self, config_class=BartConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_decoder_model_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
+
+ def test_decoder_model_attn_mask_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
+
+ def test_retain_grad_hidden_states_attentions(self):
+ # decoder cannot keep gradients
+ return
diff --git a/tests/bart/test_modeling_flax_bart.py b/tests/models/bart/test_modeling_flax_bart.py
similarity index 50%
rename from tests/bart/test_modeling_flax_bart.py
rename to tests/models/bart/test_modeling_flax_bart.py
index 219d41cae2b699..54a6ff4534df62 100644
--- a/tests/bart/test_modeling_flax_bart.py
+++ b/tests/models/bart/test_modeling_flax_bart.py
@@ -19,8 +19,8 @@
from transformers import BartConfig, BartTokenizer, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
@@ -420,7 +420,10 @@ def test_summarization_fast(self):
model = FlaxBartForConditionalGeneration.from_pretrained("sshleifer/distilbart-cnn-6-6")
tokenizer = BartTokenizer.from_pretrained("sshleifer/distilbart-cnn-6-6")
- input_str = "This sentence is made of three parts. Each part is important on its own. One part is about animals, the other part about planes, and the last part about housing."
+ input_str = (
+ "This sentence is made of three parts. Each part is important on its own. One part is about animals, the"
+ " other part about planes, and the last part about housing."
+ )
input_ids = tokenizer(input_str, return_tensors="np").input_ids
sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences
@@ -436,14 +439,197 @@ def test_cnn_summarization_same_as_fairseq(self):
model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
- FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noq
+ FRANCE_ARTICLE = ( # @noq
+ " Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
- SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
+ SHORTER_ARTICLE = (
+ " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
# The below article tests that we don't add any hypotheses outside of the top n_beams
- IRAN_ARTICLE = " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
+ IRAN_ARTICLE = (
+ " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
- ARTICLE_SUBWAY = ' New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ ARTICLE_SUBWAY = (
+ " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
dct = tokenizer.batch_encode_plus(
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
@@ -463,10 +649,21 @@ def test_cnn_summarization_same_as_fairseq(self):
assert (hypotheses_batch[:, 1] == 0).all().item()
EXPECTED = [
- "A French prosecutor says he is not aware of any video footage from on board the plane. Two German magazines claim to have found a cell phone video showing the crash. The publications say they watched the video, which was found by a source close to the investigation. All 150 on board the Germanwings flight were killed.",
- "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move toward greater justice.",
- "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The debate that has already begun will likely result in more heat than light. Bergen: The most misleading assertion is that the negotiations' objective at the outset was the total elimination of any nuclear program.",
- "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the Bronx on Friday. If convicted, Barrientos faces up to four years in prison.",
+ "A French prosecutor says he is not aware of any video footage from on board the plane. Two German"
+ " magazines claim to have found a cell phone video showing the crash. The publications say they watched"
+ " the video, which was found by a source close to the investigation. All 150 on board the Germanwings"
+ " flight were killed.",
+ "Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court"
+ " jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the"
+ " Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a"
+ " move toward greater justice.",
+ "U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The"
+ " debate that has already begun will likely result in more heat than light. Bergen: The most misleading"
+ " assertion is that the negotiations' objective at the outset was the total elimination of any nuclear"
+ " program.",
+ "Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors"
+ " say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the"
+ " Bronx on Friday. If convicted, Barrientos faces up to four years in prison.",
]
generated_summaries = tokenizer.batch_decode(
diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py
new file mode 100644
index 00000000000000..1e599c6b1ba1f4
--- /dev/null
+++ b/tests/models/bart/test_modeling_tf_bart.py
@@ -0,0 +1,937 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers import BartConfig, BartTokenizer, is_tf_available
+from transformers.testing_utils import require_tf, slow
+from transformers.utils import cached_property
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import TFBartForConditionalGeneration, TFBartModel
+
+
+@require_tf
+class TFBartModelTester:
+ config_cls = BartConfig
+ config_updates = {}
+ hidden_act = "gelu"
+
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+
+ def prepare_config_and_inputs_for_common(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
+ eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
+ input_ids = tf.concat([input_ids, eos_tensor], axis=1)
+
+ decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ config = self.config_cls(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ encoder_layers=self.num_hidden_layers,
+ decoder_layers=self.num_hidden_layers,
+ encoder_attention_heads=self.num_attention_heads,
+ decoder_attention_heads=self.num_attention_heads,
+ encoder_ffn_dim=self.intermediate_size,
+ decoder_ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_ids=[2],
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.pad_token_id,
+ **self.config_updates,
+ )
+ inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
+ return config, inputs_dict
+
+ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
+ model = TFBartModel(config=config).get_decoder()
+ input_ids = inputs_dict["input_ids"]
+
+ input_ids = input_ids[:1, :]
+ attention_mask = inputs_dict["attention_mask"][:1, :]
+ head_mask = inputs_dict["head_mask"]
+ self.batch_size = 1
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)
+
+ # append to next input_ids and
+ next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
+ next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
+
+ self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
+
+ # select random slice
+ random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
+ output_from_past_slice = output_from_past[:, :, random_slice_idx]
+
+ # test that outputs are equal for slice
+ tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
+
+
+def prepare_bart_inputs_dict(
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+):
+ if attention_mask is None:
+ attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
+ if decoder_attention_mask is None:
+ decoder_attention_mask = tf.concat(
+ [
+ tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
+ tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
+ ],
+ axis=-1,
+ )
+ if head_mask is None:
+ head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
+ if decoder_head_mask is None:
+ decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
+ if cross_attn_head_mask is None:
+ cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
+ return {
+ "input_ids": input_ids,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ }
+
+
+@require_tf
+class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
+ all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
+ all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
+ is_encoder_decoder = True
+ test_pruning = False
+ test_onnx = True
+ onnx_min_opset = 10
+
+ def setUp(self):
+ self.model_tester = TFBartModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=BartConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_decoder_model_past_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_model_common_attributes(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+ if model_class in self.all_generative_model_classes:
+ x = model.get_output_embeddings()
+ assert isinstance(x, tf.keras.layers.Layer)
+ name = model.get_bias()
+ assert isinstance(name, dict)
+ for k, v in name.items():
+ assert isinstance(v, tf.Variable)
+ else:
+ x = model.get_output_embeddings()
+ assert x is None
+ name = model.get_bias()
+ assert name is None
+
+ def test_resize_token_embeddings(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def _get_word_embedding_weight(model, embedding_layer):
+ if hasattr(embedding_layer, "weight"):
+ return embedding_layer.weight
+ else:
+ # Here we build the word embeddings weights if not exists.
+ # And then we retry to get the attribute once built.
+ model(model.dummy_inputs)
+ if hasattr(embedding_layer, "weight"):
+ return embedding_layer.weight
+ else:
+ return None
+
+ for model_class in self.all_model_classes:
+ for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
+ # build the embeddings
+ model = model_class(config=config)
+ old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
+ old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
+ old_final_logits_bias = model.get_bias()
+
+ # reshape the embeddings
+ model.resize_token_embeddings(size)
+ new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
+ new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
+ new_final_logits_bias = model.get_bias()
+
+ # check that the resized embeddings size matches the desired size.
+ assert_size = size if size is not None else config.vocab_size
+
+ self.assertEqual(new_input_embeddings.shape[0], assert_size)
+
+ # check that weights remain the same after resizing
+ models_equal = True
+ for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ if old_output_embeddings is not None and new_output_embeddings is not None:
+ self.assertEqual(new_output_embeddings.shape[0], assert_size)
+
+ models_equal = True
+ for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ if old_final_logits_bias is not None and new_final_logits_bias is not None:
+ old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
+ new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
+ self.assertEqual(new_final_logits_bias.shape[0], 1)
+ self.assertEqual(new_final_logits_bias.shape[1], assert_size)
+
+ models_equal = True
+ for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
+ for p1, p2 in zip(old, new):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ def test_saved_model_creation(self):
+ # This test is too long (>30sec) and makes fail the CI
+ pass
+
+
+def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
+ """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
+ if a is None and b is None:
+ return True
+ try:
+ if tf.debugging.assert_near(a, b, atol=atol):
+ return True
+ raise
+ except Exception:
+ if len(prefix) > 0:
+ prefix = f"{prefix}: "
+ raise AssertionError(f"{prefix}{a} != {b}")
+
+
+def _long_tensor(tok_lst):
+ return tf.constant(tok_lst, dtype=tf.int32)
+
+
+@require_tf
+class TFBartHeadTests(unittest.TestCase):
+ vocab_size = 99
+
+ def _get_config_and_data(self):
+ eos_column_vector = tf.ones((4, 1), dtype=tf.int32) * 2
+ input_ids = tf.concat([ids_tensor((4, 6), self.vocab_size - 3) + 3, eos_column_vector], axis=1)
+ batch_size = input_ids.shape[0]
+ config = BartConfig(
+ vocab_size=self.vocab_size,
+ d_model=24,
+ encoder_layers=2,
+ decoder_layers=2,
+ encoder_attention_heads=2,
+ decoder_attention_heads=2,
+ encoder_ffn_dim=32,
+ decoder_ffn_dim=32,
+ max_position_embeddings=48,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ decoder_start_token_id=2,
+ )
+ return config, input_ids, batch_size
+
+ def test_lm_forward(self):
+ config, input_ids, batch_size = self._get_config_and_data()
+ decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
+ lm_model = TFBartForConditionalGeneration(config)
+ outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
+ expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ def test_lm_uneven_forward(self):
+ config = BartConfig(
+ vocab_size=10,
+ d_model=24,
+ encoder_layers=2,
+ decoder_layers=2,
+ encoder_attention_heads=2,
+ decoder_attention_heads=2,
+ encoder_ffn_dim=32,
+ decoder_ffn_dim=32,
+ max_position_embeddings=48,
+ )
+ lm_model = TFBartForConditionalGeneration(config)
+ context = tf.fill((7, 2), 4)
+ summary = tf.fill((7, 7), 6)
+ outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
+ expected_shape = (*summary.shape, config.vocab_size)
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+
+@slow
+@require_tf
+class TFBartModelIntegrationTest(unittest.TestCase):
+ def test_inference_no_head(self):
+ model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large").model
+
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ attention_mask = tf.cast(tf.math.not_equal(input_ids, model.config.pad_token_id), tf.int8)
+ output = model(input_ids=input_ids, attention_mask=attention_mask)[0]
+ expected_shape = (1, 11, 1024)
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = tf.convert_to_tensor(
+ [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
+ )
+ tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-3)
+
+ def test_cnn_summarization_same_as_fairseq_hard(self):
+ hf = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+ tok = self.tok
+
+ FRANCE_ARTICLE = ( # @noqa
+ " Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
+ EXPECTED_SUMMARY_FRANCE = (
+ "French prosecutor says he's not aware of any video footage from on board the plane. German daily Bild"
+ " and French Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie"
+ ' spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa confirms'
+ " co-pilot Andreas Lubitz had battled depression."
+ )
+
+ SHORTER_ARTICLE = (
+ " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+ EXPECTED_SUMMARY_SHORTER = (
+ "The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives"
+ " the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States"
+ " opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said"
+ " it was a move toward greater justice."
+ )
+
+ # The below article tests that we don't add any hypotheses outside of the top n_beams
+ IRAN_ARTICLE = (
+ " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
+ EXPECTED_SUMMARY_IRAN = (
+ "The U.S. and its negotiating partners reached a very strong framework agreement with Iran. Peter Bergen:"
+ " The debate that has already begun will likely result in more heat than light. He says the agreement"
+ " limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon."
+ " Bergen says the most important aim of a nuclear deal is preventing a nuclear Iran."
+ )
+
+ ARTICLE_SUBWAY = (
+ " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
+ EXPECTED_SUMMARY_SUBWAY = (
+ "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the"
+ " marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in"
+ " the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly"
+ " sneaking into the subway."
+ )
+
+ dct = tok(
+ [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
+ max_length=1024,
+ truncation_strategy="only_first",
+ padding="longest",
+ truncation=True,
+ return_tensors="tf",
+ )
+ self.assertEqual(1024, dct["input_ids"].shape[1])
+ hypotheses_batch = hf.generate(
+ input_ids=dct["input_ids"],
+ attention_mask=dct["attention_mask"],
+ )
+
+ assert hypotheses_batch[:, 1].numpy().tolist() == [0, 0, 0, 0] # test force_bos_token_to_be_generated
+ decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ expected_batch = [
+ EXPECTED_SUMMARY_FRANCE,
+ EXPECTED_SUMMARY_SHORTER,
+ EXPECTED_SUMMARY_IRAN,
+ EXPECTED_SUMMARY_SUBWAY,
+ ]
+ assert decoded == expected_batch
+
+ @cached_property
+ def tok(self):
+ return BartTokenizer.from_pretrained("facebook/bart-large")
+
+
+@slow
+@require_tf
+class FasterTFBartModelIntegrationTests(unittest.TestCase):
+ """These tests are useful for debugging since they operate on a model with 1 encoder layer and 1 decoder layer."""
+
+ @cached_property
+ def tok(self):
+ return BartTokenizer.from_pretrained("facebook/bart-large")
+
+ @cached_property
+ def xsum_1_1_model(self):
+ return TFBartForConditionalGeneration.from_pretrained("sshleifer/distilbart-xsum-1-1")
+
+ def test_xsum_1_1_generation(self):
+ model = self.xsum_1_1_model
+ assert model.model.decoder.embed_tokens._layer == model.model.shared
+ ARTICLE = (
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes."
+ )
+ EXPECTED = (
+ " The International Criminal Court (ICC) has announced that it has been announced by the International"
+ " Criminal court."
+ )
+ dct = self.tok(ARTICLE, return_tensors="tf")
+ generated_ids = model.generate(**dct, num_beams=4)
+ result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ assert result == EXPECTED
+
+ def test_xsum_1_1_batch_generation(self):
+ batch = self.tok(
+ [
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories."
+ " The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is"
+ " based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted"
+ ' its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including'
+ ' East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination'
+ " into the situation in Palestinian territories, paving the way for possible war crimes investigations"
+ " against Israelis. As members of the court, Palestinians may be subject to counter-charges as well."
+ " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts"
+ " to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony,"
+ ' said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome'
+ ' Statute today, the world is also a step closer to ending a long era of impunity and injustice," he'
+ ' said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of'
+ ' justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was'
+ ' just the first step for the Palestinians. "As the Rome Statute today enters into force for the State'
+ " of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a"
+ ' State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she'
+ ' said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize'
+ " Palestine for joining the ICC should immediately end their pressure, and countries that support"
+ " universal acceptance of the court's treaty should speak out to welcome its membership,\" said"
+ " Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts"
+ " to undermine international justice, not Palestine's decision to join a treaty to which over 100"
+ ' countries around the world are members." In January, when the preliminary ICC examination was'
+ " opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was"
+ ' overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s'
+ ' decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we'
+ ' do not believe that it is eligible to join the ICC," the State Department said in a statement. It'
+ ' urged the warring sides to resolve their differences through direct negotiations. "We will continue'
+ ' to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said.'
+ " But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows'
+ " the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor"
+ ' Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality."'
+ " The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The"
+ " inquiry will include alleged war crimes committed since June. The International Criminal Court was"
+ " set up in 2002 to prosecute genocide, crimes against humanity and war crimes.",
+ "The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted"
+ " Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor"
+ ' Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A'
+ " person who has such a video needs to immediately give it to the investigators.\" Robin's comments"
+ " follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the"
+ " French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was"
+ " recovered from a phone at the wreckage site. The two publications described the supposed video, but"
+ " did not post it on their websites. The publications said that they watched the video, which was"
+ " found by a source close to the investigation. \"One can hear cries of 'My God' in several"
+ ' languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps'
+ " of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy"
+ ' shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing'
+ " scene,\" said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident"
+ " investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc"
+ " Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the"
+ ' Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell'
+ ' phones have been collected at the site, he said, but that they "hadn\'t been exploited yet."'
+ " Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute"
+ " in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working"
+ " hand-in-hand with investigators. But none of the cell phones found so far have been sent to the"
+ " institute, Menichini said. Asked whether staff involved in the search could have leaked a memory"
+ ' card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett:'
+ ' Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are'
+ ' "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is'
+ " something we did not know before. ... Overall we can say many things of the investigation weren't"
+ ' revealed by the investigation at the beginning," he said. What was mental state of Germanwings'
+ " co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled"
+ " depression years before he took the controls of Germanwings Flight 9525, which he's accused of"
+ " deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school"
+ ' in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email'
+ " correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa"
+ " said, included medical documents he submitted in connection with resuming his flight training. The"
+ " announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle"
+ " with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa,"
+ " whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday"
+ ' as a "swift and seamless clarification" and said it was sharing the information and documents --'
+ " including training and medical records -- with public prosecutors. Spohr traveled to the crash site"
+ " Wednesday, where recovery teams have been working for the past week to recover human remains and"
+ " plane debris scattered across a steep mountainside. He saw the crisis center set up in"
+ " Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving"
+ " families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no"
+ " visible human remains were left at the site but recovery teams would keep searching. French"
+ " President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the"
+ " victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini"
+ " said. Among those personal belongings could be more cell phones belonging to the 144 passengers and"
+ " six crew on board. Check out the latest from our correspondents . The details about Lubitz's"
+ " correspondence with the flight school during his training were among several developments as"
+ " investigators continued to delve into what caused the crash and Lubitz's possible motive for"
+ " downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical"
+ ' certificate, had passed all his examinations and "held all the licenses required." Earlier, a'
+ " spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal"
+ " Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent"
+ " psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting"
+ " Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether"
+ " Lubitz feared his medical condition would cause him to lose his pilot's license, a European"
+ ' government official briefed on the investigation told CNN on Tuesday. While flying was "a big part'
+ " of his life,\" the source said, it's only one theory being considered. Another source, a law"
+ " enforcement official briefed on the investigation, also told CNN that authorities believe the"
+ " primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly"
+ " because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor"
+ " and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had"
+ " psychological issues, the European government official said. But no matter what details emerge about"
+ " his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the'
+ " fact that maybe they weren't going to keep doing their job and they're upset about that and so"
+ ' they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels'
+ " entitled to also take that rage and turn it outward on 149 other people who had nothing to do with"
+ " the person's problems.\" Germanwings crash compensation: What we know . Who was the captain of"
+ " Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from"
+ " Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff,"
+ " Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.",
+ ],
+ return_tensors="tf",
+ padding="longest",
+ truncation=True,
+ )
+ generated_ids = self.xsum_1_1_model.generate(**batch, num_beams=4)
+ result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
+ assert (
+ result[0]
+ == " The International Criminal Court (ICC) has announced that it has been announced by the International"
+ " Criminal court."
+ )
+ assert (
+ result[1]
+ == " An investigation into the crash that killed at least 10 people in the French capital has been"
+ " released by the French police investigating the crash."
+ )
+
+ def test_encoder_equiv(self):
+ batch = self.tok(
+ [
+ "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories."
+ " The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is"
+ " based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted"
+ ' its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including'
+ ' East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination'
+ " into the situation in Palestinian territories, paving the way for possible war crimes investigations"
+ " against Israelis. As members of the court, Palestinians may be subject to counter-charges as well."
+ " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts"
+ " to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony,"
+ ' said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome'
+ ' Statute today, the world is also a step closer to ending a long era of impunity and injustice," he'
+ ' said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of'
+ ' justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was'
+ ' just the first step for the Palestinians. "As the Rome Statute today enters into force for the State'
+ " of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a"
+ ' State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she'
+ ' said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize'
+ " Palestine for joining the ICC should immediately end their pressure, and countries that support"
+ " universal acceptance of the court's treaty should speak out to welcome its membership,\" said"
+ " Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts"
+ " to undermine international justice, not Palestine's decision to join a treaty to which over 100"
+ ' countries around the world are members." In January, when the preliminary ICC examination was'
+ " opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was"
+ ' overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s'
+ ' decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we'
+ ' do not believe that it is eligible to join the ICC," the State Department said in a statement. It'
+ ' urged the warring sides to resolve their differences through direct negotiations. "We will continue'
+ ' to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said.'
+ " But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows'
+ " the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor"
+ ' Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality."'
+ " The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The"
+ " inquiry will include alleged war crimes committed since June. The International Criminal Court was"
+ " set up in 2002 to prosecute genocide, crimes against humanity and war crimes.",
+ "The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted"
+ " Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor"
+ ' Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A'
+ " person who has such a video needs to immediately give it to the investigators.\" Robin's comments"
+ " follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the"
+ " French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was"
+ " recovered from a phone at the wreckage site. The two publications described the supposed video, but"
+ " did not post it on their websites. The publications said that they watched the video, which was"
+ " found by a source close to the investigation. \"One can hear cries of 'My God' in several"
+ ' languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps'
+ " of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy"
+ ' shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing'
+ " scene,\" said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident"
+ " investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc"
+ " Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the"
+ ' Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell'
+ ' phones have been collected at the site, he said, but that they "hadn\'t been exploited yet."'
+ " Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute"
+ " in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working"
+ " hand-in-hand with investigators. But none of the cell phones found so far have been sent to the"
+ " institute, Menichini said. Asked whether staff involved in the search could have leaked a memory"
+ ' card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett:'
+ ' Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are'
+ ' "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is'
+ " something we did not know before. ... Overall we can say many things of the investigation weren't"
+ ' revealed by the investigation at the beginning," he said. What was mental state of Germanwings'
+ " co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled"
+ " depression years before he took the controls of Germanwings Flight 9525, which he's accused of"
+ " deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school"
+ ' in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email'
+ " correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa"
+ " said, included medical documents he submitted in connection with resuming his flight training. The"
+ " announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle"
+ " with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa,"
+ " whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday"
+ ' as a "swift and seamless clarification" and said it was sharing the information and documents --'
+ " including training and medical records -- with public prosecutors. Spohr traveled to the crash site"
+ " Wednesday, where recovery teams have been working for the past week to recover human remains and"
+ " plane debris scattered across a steep mountainside. He saw the crisis center set up in"
+ " Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving"
+ " families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no"
+ " visible human remains were left at the site but recovery teams would keep searching. French"
+ " President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the"
+ " victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini"
+ " said. Among those personal belongings could be more cell phones belonging to the 144 passengers and"
+ " six crew on board. Check out the latest from our correspondents . The details about Lubitz's"
+ " correspondence with the flight school during his training were among several developments as"
+ " investigators continued to delve into what caused the crash and Lubitz's possible motive for"
+ " downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical"
+ ' certificate, had passed all his examinations and "held all the licenses required." Earlier, a'
+ " spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal"
+ " Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent"
+ " psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting"
+ " Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether"
+ " Lubitz feared his medical condition would cause him to lose his pilot's license, a European"
+ ' government official briefed on the investigation told CNN on Tuesday. While flying was "a big part'
+ " of his life,\" the source said, it's only one theory being considered. Another source, a law"
+ " enforcement official briefed on the investigation, also told CNN that authorities believe the"
+ " primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly"
+ " because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor"
+ " and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had"
+ " psychological issues, the European government official said. But no matter what details emerge about"
+ " his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the'
+ " fact that maybe they weren't going to keep doing their job and they're upset about that and so"
+ ' they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels'
+ " entitled to also take that rage and turn it outward on 149 other people who had nothing to do with"
+ " the person's problems.\" Germanwings crash compensation: What we know . Who was the captain of"
+ " Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from"
+ " Dusseldorf, while Laura Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff,"
+ " Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.",
+ ],
+ return_tensors="tf",
+ padding="longest",
+ truncation=True,
+ )
+ features = self.xsum_1_1_model.get_encoder()(**batch).last_hidden_state
+
+ expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]])
+ assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3)
diff --git a/tests/bart/test_tokenization_bart.py b/tests/models/bart/test_tokenization_bart.py
similarity index 98%
rename from tests/bart/test_tokenization_bart.py
rename to tests/models/bart/test_tokenization_bart.py
index 66e5e0b9e3ffd6..b8e216e69ba221 100644
--- a/tests/bart/test_tokenization_bart.py
+++ b/tests/models/bart/test_tokenization_bart.py
@@ -20,7 +20,7 @@
from transformers.testing_utils import require_tokenizers, require_torch
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin, filter_roberta_detectors
+from ...test_tokenization_common import TokenizerTesterMixin, filter_roberta_detectors
@require_tokenizers
diff --git a/tests/beit/__init__.py b/tests/models/barthez/__init__.py
similarity index 100%
rename from tests/beit/__init__.py
rename to tests/models/barthez/__init__.py
diff --git a/tests/barthez/test_tokenization_barthez.py b/tests/models/barthez/test_tokenization_barthez.py
similarity index 98%
rename from tests/barthez/test_tokenization_barthez.py
rename to tests/models/barthez/test_tokenization_barthez.py
index 2738ec6e306f65..38acf046b4f3e9 100644
--- a/tests/barthez/test_tokenization_barthez.py
+++ b/tests/models/barthez/test_tokenization_barthez.py
@@ -18,7 +18,7 @@
from transformers import BarthezTokenizer, BarthezTokenizerFast, BatchEncoding
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/bert/__init__.py b/tests/models/bartpho/__init__.py
similarity index 100%
rename from tests/bert/__init__.py
rename to tests/models/bartpho/__init__.py
diff --git a/tests/bartpho/test_tokenization_bartpho.py b/tests/models/bartpho/test_tokenization_bartpho.py
similarity index 92%
rename from tests/bartpho/test_tokenization_bartpho.py
rename to tests/models/bartpho/test_tokenization_bartpho.py
index 3e35ad15c1ee54..fc5ebfd19c4a14 100644
--- a/tests/bartpho/test_tokenization_bartpho.py
+++ b/tests/models/bartpho/test_tokenization_bartpho.py
@@ -15,14 +15,14 @@
import os
import unittest
-from os.path import dirname
from transformers.models.bartpho.tokenization_bartpho import VOCAB_FILES_NAMES, BartphoTokenizer
+from transformers.testing_utils import get_tests_dir
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece_bpe.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe.model")
class BartphoTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/bert_generation/__init__.py b/tests/models/beit/__init__.py
similarity index 100%
rename from tests/bert_generation/__init__.py
rename to tests/models/beit/__init__.py
diff --git a/tests/beit/test_feature_extraction_beit.py b/tests/models/beit/test_feature_extraction_beit.py
similarity index 99%
rename from tests/beit/test_feature_extraction_beit.py
rename to tests/models/beit/test_feature_extraction_beit.py
index 71d78a26608162..a9338aea1fc1fc 100644
--- a/tests/beit/test_feature_extraction_beit.py
+++ b/tests/models/beit/test_feature_extraction_beit.py
@@ -22,7 +22,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py
similarity index 76%
rename from tests/beit/test_modeling_beit.py
rename to tests/models/beit/test_modeling_beit.py
index 59776f83553bfb..8c9202c34b1025 100644
--- a/tests/beit/test_modeling_beit.py
+++ b/tests/models/beit/test_modeling_beit.py
@@ -26,8 +26,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
@@ -96,9 +96,9 @@ def __init__(
self.out_indices = out_indices
self.num_labels = num_labels
- # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ # in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
- self.expected_seq_length = num_patches + 1
+ self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -136,16 +136,14 @@ def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
model.to(torch_device)
model.eval()
result = model(pixel_values)
- self.parent.assertEqual(
- result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
- )
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
model = BeitForMaskedImageModeling(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
- self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length - 1, self.vocab_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
config.num_labels = self.type_sequence_label_size
@@ -155,7 +153,7 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
- def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels):
+ def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels):
config.num_labels = self.num_labels
model = BeitForSemanticSegmentation(config)
model.to(torch_device)
@@ -200,8 +198,8 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()
+ @unittest.skip(reason="BEiT does not use inputs_embeds")
def test_inputs_embeds(self):
- # BEiT does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -229,9 +227,17 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
- def test_for_image_segmentation(self):
+ def test_for_masked_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ def test_for_semantic_segmentation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
+ self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)
def test_training(self):
if not self.model_tester.is_training:
@@ -267,13 +273,7 @@ def test_training_gradient_checkpointing(self):
or not model_class.supports_gradient_checkpointing
):
continue
- # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
- # this can then be incorporated into _prepare_for_class in test_modeling_common.py
- elif model_class.__name__ == "BeitForSemanticSegmentation":
- batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
- inputs_dict["labels"] = torch.zeros(
- [self.model_tester.batch_size, height, width], device=torch_device
- ).long()
+
model = model_class(config)
model.gradient_checkpointing_enable()
model.to(torch_device)
@@ -300,106 +300,6 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- # BEiT has a different seq_length
- seq_len = self.model_tester.expected_seq_length
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- self.assertEqual(out_len + 1, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- # BEiT has a different seq_length
- seq_length = self.model_tester.expected_seq_length
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
- def test_for_masked_lm(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
-
- def test_for_image_classification(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
-
@slow
def test_model_from_pretrained(self):
for model_name in BEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
diff --git a/tests/beit/test_modeling_flax_beit.py b/tests/models/beit/test_modeling_flax_beit.py
similarity index 73%
rename from tests/beit/test_modeling_flax_beit.py
rename to tests/models/beit/test_modeling_flax_beit.py
index 8977ab6542e041..50996dedc7af52 100644
--- a/tests/beit/test_modeling_flax_beit.py
+++ b/tests/models/beit/test_modeling_flax_beit.py
@@ -21,8 +21,8 @@
from transformers.testing_utils import require_flax, require_vision, slow
from transformers.utils import cached_property, is_flax_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor
if is_flax_available():
@@ -75,9 +75,9 @@ def __init__(
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
- # in BeiT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ # in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
- self.expected_seq_length = num_patches + 1
+ self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -108,14 +108,12 @@ def create_and_check_model(self, config, pixel_values, labels):
model = FlaxBeitModel(config=config)
result = model(pixel_values)
- self.parent.assertEqual(
- result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
- )
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_masked_lm(self, config, pixel_values, labels):
model = FlaxBeitForMaskedImageModeling(config=config)
result = model(pixel_values)
- self.parent.assertEqual(result.logits.shape, (self.batch_size, self.expected_seq_length - 1, self.vocab_size))
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length - 1, self.vocab_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
@@ -148,51 +146,7 @@ def setUp(self) -> None:
def test_config(self):
self.config_tester.run_common_tests()
- # We need to override this test because in Beit, the seq_len equals the number of patches + 1
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- seq_length = self.model_tester.expected_seq_length
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_length, seq_length],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
-
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_length, seq_length],
- )
-
- # We neeed to override this test because Beit's forward signature is different than text models.
+ # We need to override this test because Beit's forward signature is different than text models.
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -229,34 +183,6 @@ def model_jitted(pixel_values, **kwargs):
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
- # We need to override this test because in Beit, the seq_len equals the number of patches + 1
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- seq_length = self.model_tester.expected_seq_length
-
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- hidden_states = outputs.hidden_states
-
- self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
diff --git a/tests/bert_japanese/__init__.py b/tests/models/bert/__init__.py
similarity index 100%
rename from tests/bert_japanese/__init__.py
rename to tests/models/bert/__init__.py
diff --git a/tests/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py
old mode 100755
new mode 100644
similarity index 99%
rename from tests/bert/test_modeling_bert.py
rename to tests/models/bert/test_modeling_bert.py
index efef037627fa13..ca4223aacd42b5
--- a/tests/bert/test_modeling_bert.py
+++ b/tests/models/bert/test_modeling_bert.py
@@ -20,9 +20,9 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/bert/test_modeling_flax_bert.py b/tests/models/bert/test_modeling_flax_bert.py
similarity index 88%
rename from tests/bert/test_modeling_flax_bert.py
rename to tests/models/bert/test_modeling_flax_bert.py
index 0214e379010d34..5516c4d6fe67fe 100644
--- a/tests/bert/test_modeling_flax_bert.py
+++ b/tests/models/bert/test_modeling_flax_bert.py
@@ -19,7 +19,7 @@
from transformers import BertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
@@ -114,6 +114,22 @@ def prepare_config_and_inputs_for_common(self):
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict
+ def prepare_config_and_inputs_for_decoder(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, token_type_ids, attention_mask = config_and_inputs
+
+ config.is_decoder = True
+ encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+ encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+ return (
+ config,
+ input_ids,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
diff --git a/tests/bert/test_modeling_tf_bert.py b/tests/models/bert/test_modeling_tf_bert.py
similarity index 99%
rename from tests/bert/test_modeling_tf_bert.py
rename to tests/models/bert/test_modeling_tf_bert.py
index 8c709e093801a8..e83ae9f71802d0 100644
--- a/tests/bert/test_modeling_tf_bert.py
+++ b/tests/models/bert/test_modeling_tf_bert.py
@@ -20,9 +20,9 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
-from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
if is_tf_available():
diff --git a/tests/bert/test_tokenization_bert.py b/tests/models/bert/test_tokenization_bert.py
similarity index 99%
rename from tests/bert/test_tokenization_bert.py
rename to tests/models/bert/test_tokenization_bert.py
index f53482eef75662..dfbcd266c49917 100644
--- a/tests/bert/test_tokenization_bert.py
+++ b/tests/models/bert/test_tokenization_bert.py
@@ -29,7 +29,7 @@
)
from transformers.testing_utils import require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin, filter_non_english
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
@require_tokenizers
@@ -187,7 +187,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
diff --git a/tests/bertweet/__init__.py b/tests/models/bert_generation/__init__.py
similarity index 100%
rename from tests/bertweet/__init__.py
rename to tests/models/bert_generation/__init__.py
diff --git a/tests/bert_generation/test_modeling_bert_generation.py b/tests/models/bert_generation/test_modeling_bert_generation.py
old mode 100755
new mode 100644
similarity index 98%
rename from tests/bert_generation/test_modeling_bert_generation.py
rename to tests/models/bert_generation/test_modeling_bert_generation.py
index 73cd77ac0f3349..f5cbd61a1d606a
--- a/tests/bert_generation/test_modeling_bert_generation.py
+++ b/tests/models/bert_generation/test_modeling_bert_generation.py
@@ -19,9 +19,9 @@
from transformers import BertGenerationConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/bert_generation/test_tokenization_bert_generation.py b/tests/models/bert_generation/test_tokenization_bert_generation.py
similarity index 94%
rename from tests/bert_generation/test_tokenization_bert_generation.py
rename to tests/models/bert_generation/test_tokenization_bert_generation.py
index d21589526bab2a..581f249db050cb 100644
--- a/tests/bert_generation/test_tokenization_bert_generation.py
+++ b/tests/models/bert_generation/test_tokenization_bert_generation.py
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from transformers import BertGenerationTokenizer
-from transformers.testing_utils import require_sentencepiece, require_torch, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, slow
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
SPIECE_UNDERLINE = "ā"
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_sentencepiece
@@ -146,7 +144,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to , such as saoneuhaoesuth"
+ )
original_tokenizer_encodings = [
871,
419,
diff --git a/tests/big_bird/__init__.py b/tests/models/bert_japanese/__init__.py
similarity index 100%
rename from tests/big_bird/__init__.py
rename to tests/models/bert_japanese/__init__.py
diff --git a/tests/bert_japanese/test_tokenization_bert_japanese.py b/tests/models/bert_japanese/test_tokenization_bert_japanese.py
similarity index 97%
rename from tests/bert_japanese/test_tokenization_bert_japanese.py
rename to tests/models/bert_japanese/test_tokenization_bert_japanese.py
index 47a7d2ea036d9e..86b3f16f101e03 100644
--- a/tests/bert_japanese/test_tokenization_bert_japanese.py
+++ b/tests/models/bert_japanese/test_tokenization_bert_japanese.py
@@ -29,7 +29,7 @@
)
from transformers.testing_utils import custom_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@custom_tokenizers
@@ -176,7 +176,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "ććć«ć”ćÆ", "ćć", "ć«ć”ćÆ", "ć°ććÆ", "##ćć", "##ć«ć”ćÆ", "##ć°ććÆ"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
@@ -249,7 +249,7 @@ def test_character_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "ć", "ć", "ć«", "ć”", "ćÆ", "ć°", "äø", "ē", "ć", "ć"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")
@@ -288,7 +288,8 @@ def test_tokenizer_mismatch_warning(self):
BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
self.assertTrue(
cm.records[0].message.startswith(
- "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+ "The tokenizer class you load from this checkpoint is not the same type as the class this function"
+ " is called from."
)
)
EXAMPLE_BERT_ID = "bert-base-cased"
@@ -296,6 +297,7 @@ def test_tokenizer_mismatch_warning(self):
BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID)
self.assertTrue(
cm.records[0].message.startswith(
- "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+ "The tokenizer class you load from this checkpoint is not the same type as the class this function"
+ " is called from."
)
)
diff --git a/tests/bigbird_pegasus/__init__.py b/tests/models/bertweet/__init__.py
similarity index 100%
rename from tests/bigbird_pegasus/__init__.py
rename to tests/models/bertweet/__init__.py
diff --git a/tests/bertweet/test_tokenization_bertweet.py b/tests/models/bertweet/test_tokenization_bertweet.py
similarity index 97%
rename from tests/bertweet/test_tokenization_bertweet.py
rename to tests/models/bertweet/test_tokenization_bertweet.py
index edeb8ae81a9d9d..5f82fba516754b 100644
--- a/tests/bertweet/test_tokenization_bertweet.py
+++ b/tests/models/bertweet/test_tokenization_bertweet.py
@@ -18,7 +18,7 @@
from transformers.models.bertweet.tokenization_bertweet import VOCAB_FILES_NAMES, BertweetTokenizer
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class BertweetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/blenderbot/__init__.py b/tests/models/big_bird/__init__.py
similarity index 100%
rename from tests/blenderbot/__init__.py
rename to tests/models/big_bird/__init__.py
diff --git a/tests/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py
similarity index 94%
rename from tests/big_bird/test_modeling_big_bird.py
rename to tests/models/big_bird/test_modeling_big_bird.py
index 24b88fd423728c..ba09241af95314 100644
--- a/tests/big_bird/test_modeling_big_bird.py
+++ b/tests/models/big_bird/test_modeling_big_bird.py
@@ -22,8 +22,8 @@
from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -799,7 +799,16 @@ def test_tokenizer_inference(self):
model.to(torch_device)
text = [
- "Transformer-based models are unable to process long sequences due to their self-attention operation, which scales quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or longer. Longformerās attention mechanism is a drop-in replacement for the standard self-attention and combines a local windowed attention with a task motivated global attention. Following prior work on long-sequence transformers, we evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream tasks. Our pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on WikiHop and TriviaQA."
+ "Transformer-based models are unable to process long sequences due to their self-attention operation,"
+ " which scales quadratically with the sequence length. To address this limitation, we introduce the"
+ " Longformer with an attention mechanism that scales linearly with sequence length, making it easy to"
+ " process documents of thousands of tokens or longer. Longformerās attention mechanism is a drop-in"
+ " replacement for the standard self-attention and combines a local windowed attention with a task"
+ " motivated global attention. Following prior work on long-sequence transformers, we evaluate Longformer"
+ " on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In"
+ " contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream"
+ " tasks. Our pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new"
+ " state-of-the-art results on WikiHop and TriviaQA."
]
inputs = tokenizer(text)
@@ -837,7 +846,18 @@ def test_inference_question_answering(self):
)
model.to(torch_device)
- context = "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and random attention approximates full attention, while being computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, BigBird has shown improved performance on various long document NLP tasks, such as question answering and summarization, compared to BERT or RoBERTa."
+ context = (
+ "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and"
+ " Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago"
+ " and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a"
+ " sparse-attention based transformer which extends Transformer based models, such as BERT to much longer"
+ " sequences. In addition to sparse attention, BigBird also applies global attention as well as random"
+ " attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and"
+ " random attention approximates full attention, while being computationally much more efficient for longer"
+ " sequences. As a consequence of the capability to handle longer context, BigBird has shown improved"
+ " performance on various long document NLP tasks, such as question answering and summarization, compared"
+ " to BERT or RoBERTa."
+ )
question = [
"Which is better for longer sequences- BigBird or BERT?",
diff --git a/tests/big_bird/test_modeling_flax_big_bird.py b/tests/models/big_bird/test_modeling_flax_big_bird.py
similarity index 86%
rename from tests/big_bird/test_modeling_flax_big_bird.py
rename to tests/models/big_bird/test_modeling_flax_big_bird.py
index 5946129316b4f6..3a07996e7aed06 100644
--- a/tests/big_bird/test_modeling_flax_big_bird.py
+++ b/tests/models/big_bird/test_modeling_flax_big_bird.py
@@ -19,12 +19,13 @@
from transformers import BigBirdConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
import jax
from transformers.models.big_bird.modeling_flax_big_bird import (
+ FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
@@ -39,17 +40,17 @@ class FlaxBigBirdModelTester(unittest.TestCase):
def __init__(
self,
parent,
- batch_size=13,
+ batch_size=2,
seq_length=56,
is_training=True,
use_attention_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
- hidden_size=32,
- num_hidden_layers=5,
- num_attention_heads=4,
- intermediate_size=37,
+ hidden_size=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ intermediate_size=7,
hidden_act="gelu_new",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
@@ -61,7 +62,7 @@ def __init__(
attention_type="block_sparse",
use_bias=True,
rescale_embeddings=False,
- block_size=4,
+ block_size=2,
num_random_blocks=3,
):
self.parent = parent
@@ -136,6 +137,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
+ FlaxBigBirdForCausalLM,
FlaxBigBirdModel,
FlaxBigBirdForPreTraining,
FlaxBigBirdForMaskedLM,
@@ -154,10 +156,30 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = FlaxBigBirdModelTester(self)
+ @slow
+ # copied from `test_modeling_flax_common` because it takes much longer than other models
+ def test_from_pretrained_save_pretrained(self):
+ super().test_from_pretrained_save_pretrained()
+
+ @slow
+ # copied from `test_modeling_flax_common` because it takes much longer than other models
+ def test_from_pretrained_with_no_automatic_init(self):
+ super().test_from_pretrained_with_no_automatic_init()
+
+ @slow
+ # copied from `test_modeling_flax_common` because it takes much longer than other models
+ def test_no_automatic_init(self):
+ super().test_no_automatic_init()
+
+ @slow
+ # copied from `test_modeling_flax_common` because it takes much longer than other models
+ def test_hidden_states_output(self):
+ super().test_hidden_states_output()
+
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
- model = model_class_name.from_pretrained("google/bigbird-roberta-base", from_pt=True)
+ model = model_class_name.from_pretrained("google/bigbird-roberta-base")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
diff --git a/tests/big_bird/test_tokenization_big_bird.py b/tests/models/big_bird/test_tokenization_big_bird.py
similarity index 95%
rename from tests/big_bird/test_tokenization_big_bird.py
rename to tests/models/big_bird/test_tokenization_big_bird.py
index d1d0daeda7e934..ff654510082574 100644
--- a/tests/big_bird/test_tokenization_big_bird.py
+++ b/tests/models/big_bird/test_tokenization_big_bird.py
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from transformers import BigBirdTokenizer, BigBirdTokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, require_torch, slow
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
SPIECE_UNDERLINE = "ā"
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_sentencepiece
@@ -170,7 +168,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to , such as saoneuhaoesuth"
+ )
# fmt: off
original_tokenizer_encodings = [65, 871, 419, 358, 946, 991, 2521, 452, 358, 1357, 387, 7751, 3536, 112, 985, 456, 126, 865, 938, 5400, 5734, 458, 1368, 467, 786, 2462, 5246, 1159, 633, 865, 4519, 457, 582, 852, 2557, 427, 916, 508, 405, 34324, 497, 391, 408, 11342, 1244, 385, 100, 938, 985, 456, 574, 362, 12597, 3200, 3129, 1172, 66] # noqa: E231
# fmt: on
diff --git a/tests/blenderbot_small/__init__.py b/tests/models/bigbird_pegasus/__init__.py
similarity index 100%
rename from tests/blenderbot_small/__init__.py
rename to tests/models/bigbird_pegasus/__init__.py
diff --git a/tests/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
similarity index 98%
rename from tests/bigbird_pegasus/test_modeling_bigbird_pegasus.py
rename to tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
index eebdb0a91c11fe..d4e7e8f4ae422a 100644
--- a/tests/bigbird_pegasus/test_modeling_bigbird_pegasus.py
+++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py
@@ -22,9 +22,9 @@
from transformers import BigBirdPegasusConfig, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -538,9 +538,26 @@ def test_seq_to_seq_generation(self):
hypotheses_batch = model.generate(**inputs)
- EXPECTED_LEP = "motivated by some recent studies on the light cp - odd higgs boson @xmath0 in non - minimal supersymmetric models, we investigate the rare @xmath1-decays @xmath2 ( @xmath3 ) in the two higgs doublet model ( 2hdm ), the nearly minimal supersymmetric standard model ( nmssm ), the next - to - minimal supersymmetric standard model ( nmssm ) and the minimal supersymmetric standard model ( mssm ). we find that the branching ratios of @xmath4 can reach @xmath5 in 2hdm, @xmath6 in nmssm and @xmath7 in mssm, which are at the level of @xmath8 in 2hdm, @xmath9 in nmssm and @xmath10 in mssm, respectively. these rates can be significantly enhanced in new physics models which lie within the expected sensitivity of the gigaz option of the international linear collider ( ilc ). = # 1,nucl. phys. b * # 1"
+ EXPECTED_LEP = (
+ "motivated by some recent studies on the light cp - odd higgs boson @xmath0 in non - minimal"
+ " supersymmetric models, we investigate the rare @xmath1-decays @xmath2 ( @xmath3 ) in the two higgs"
+ " doublet model ( 2hdm ), the nearly minimal supersymmetric standard model ( nmssm ), the next - to -"
+ " minimal supersymmetric standard model ( nmssm ) and the minimal supersymmetric standard model ( mssm"
+ " ). we find that the branching ratios of @xmath4 can reach @xmath5 in 2hdm, @xmath6 in nmssm and"
+ " @xmath7 in mssm, which are at the level of @xmath8 in 2hdm, @xmath9 in nmssm and @xmath10 in mssm,"
+ " respectively. these rates can be significantly enhanced in new physics models which lie within the"
+ " expected sensitivity of the gigaz option of the international linear collider ( ilc ). = # 1,nucl."
+ " phys. b * # 1"
+ )
- EXPECTED_MAGNET = "a positive, nonsaturating and dominantly linear magnetoresistance can appear within quite wide magnetic - field range in the surface state of a topological insulator having a positive and finite effective g - factor. this linear magnetoresistance shows up in the system of high carrier concentration and low mobility when electrons are in extended states and spread over many smeared landau levels, and persists up to room temperature, providing a possible mechanism for the recently observed linear magnetoresistance in topological insulator bi@xmath0se@xmath1 nanoribbons."
+ EXPECTED_MAGNET = (
+ "a positive, nonsaturating and dominantly linear magnetoresistance can appear within quite wide magnetic -"
+ " field range in the surface state of a topological insulator having a positive and finite effective g -"
+ " factor. this linear magnetoresistance shows up in the system of high carrier concentration and low"
+ " mobility when electrons are in extended states and spread over many smeared landau levels, and persists"
+ " up to room temperature, providing a possible mechanism for the recently observed linear"
+ " magnetoresistance in topological insulator bi@xmath0se@xmath1 nanoribbons."
+ )
generated = tokenizer.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
diff --git a/tests/bort/__init__.py b/tests/models/blenderbot/__init__.py
similarity index 100%
rename from tests/bort/__init__.py
rename to tests/models/blenderbot/__init__.py
diff --git a/tests/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py
similarity index 97%
rename from tests/blenderbot/test_modeling_blenderbot.py
rename to tests/models/blenderbot/test_modeling_blenderbot.py
index 4a8070bfa79c61..ee76626ffed604 100644
--- a/tests/blenderbot/test_modeling_blenderbot.py
+++ b/tests/models/blenderbot/test_modeling_blenderbot.py
@@ -21,9 +21,9 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -140,6 +140,7 @@ def get_config(self):
def get_pipeline_config(self):
config = self.get_config()
config.max_position_embeddings = 100
+ config.vocab_size = 300
return config
def prepare_config_and_inputs_for_common(self):
@@ -218,6 +219,7 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
all_model_classes = (BlenderbotModel, BlenderbotForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -304,7 +306,10 @@ def test_generation_from_short_input_same_as_parlai_3B(self):
generated_txt = self.tokenizer.batch_decode(generated_utterances, **TOK_DECODE_KW)
assert generated_txt[0].strip() == tgt_text
- src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
+ src_text = (
+ "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel"
+ " like i'm going to throw up.\nand why is that?"
+ )
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
diff --git a/tests/blenderbot/test_modeling_flax_blenderbot.py b/tests/models/blenderbot/test_modeling_flax_blenderbot.py
similarity index 99%
rename from tests/blenderbot/test_modeling_flax_blenderbot.py
rename to tests/models/blenderbot/test_modeling_flax_blenderbot.py
index cf6b8b9083b21f..fad60bcced9d45 100644
--- a/tests/blenderbot/test_modeling_flax_blenderbot.py
+++ b/tests/models/blenderbot/test_modeling_flax_blenderbot.py
@@ -20,8 +20,8 @@
from transformers import BlenderbotConfig, is_flax_available
from transformers.testing_utils import jax_device, require_flax, slow
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
diff --git a/tests/blenderbot/test_modeling_tf_blenderbot.py b/tests/models/blenderbot/test_modeling_tf_blenderbot.py
similarity index 99%
rename from tests/blenderbot/test_modeling_tf_blenderbot.py
rename to tests/models/blenderbot/test_modeling_tf_blenderbot.py
index e9e9816b17824b..a8ca54558f06ac 100644
--- a/tests/blenderbot/test_modeling_tf_blenderbot.py
+++ b/tests/models/blenderbot/test_modeling_tf_blenderbot.py
@@ -20,8 +20,8 @@
from transformers.testing_utils import require_tf, require_tokenizers, slow
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
diff --git a/tests/blenderbot/test_tokenization_blenderbot.py b/tests/models/blenderbot/test_tokenization_blenderbot.py
similarity index 100%
rename from tests/blenderbot/test_tokenization_blenderbot.py
rename to tests/models/blenderbot/test_tokenization_blenderbot.py
diff --git a/tests/byt5/__init__.py b/tests/models/blenderbot_small/__init__.py
similarity index 100%
rename from tests/byt5/__init__.py
rename to tests/models/blenderbot_small/__init__.py
diff --git a/tests/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py
similarity index 98%
rename from tests/blenderbot_small/test_modeling_blenderbot_small.py
rename to tests/models/blenderbot_small/test_modeling_blenderbot_small.py
index 7e8ce4b624e0dd..47503b9c7f3653 100644
--- a/tests/blenderbot_small/test_modeling_blenderbot_small.py
+++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py
@@ -21,9 +21,9 @@
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -213,6 +213,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -290,8 +291,8 @@ def tokenizer(self):
def test_90_generation_from_long_input(self):
src_text = [
- "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like\
- i'm going to throw up.\nand why is that?"
+ "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel"
+ " like i'm going to throw up.\nand why is that?"
]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
diff --git a/tests/blenderbot_small/test_modeling_flax_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_flax_blenderbot_small.py
similarity index 99%
rename from tests/blenderbot_small/test_modeling_flax_blenderbot_small.py
rename to tests/models/blenderbot_small/test_modeling_flax_blenderbot_small.py
index 6f674624265f47..3cbacfc8d89237 100644
--- a/tests/blenderbot_small/test_modeling_flax_blenderbot_small.py
+++ b/tests/models/blenderbot_small/test_modeling_flax_blenderbot_small.py
@@ -20,8 +20,8 @@
from transformers import BlenderbotSmallConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
diff --git a/tests/blenderbot_small/test_modeling_tf_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py
similarity index 98%
rename from tests/blenderbot_small/test_modeling_tf_blenderbot_small.py
rename to tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py
index 33c5b286c9cf73..f8543aad59d85c 100644
--- a/tests/blenderbot_small/test_modeling_tf_blenderbot_small.py
+++ b/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py
@@ -20,8 +20,8 @@
from transformers.testing_utils import require_tf, require_tokenizers, slow
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
@@ -305,7 +305,8 @@ def _long_tensor(tok_lst):
@require_tf
class TFBlenderbot90MIntegrationTests(unittest.TestCase):
src_text = [
- "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
+ "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like "
+ " i'm going to throw up.\nand why is that?"
]
model_name = "facebook/blenderbot_small-90M"
diff --git a/tests/blenderbot_small/test_tokenization_blenderbot_small.py b/tests/models/blenderbot_small/test_tokenization_blenderbot_small.py
similarity index 98%
rename from tests/blenderbot_small/test_tokenization_blenderbot_small.py
rename to tests/models/blenderbot_small/test_tokenization_blenderbot_small.py
index 38c3f8391d2276..7ea7f09b5764bf 100644
--- a/tests/blenderbot_small/test_tokenization_blenderbot_small.py
+++ b/tests/models/blenderbot_small/test_tokenization_blenderbot_small.py
@@ -23,7 +23,7 @@
BlenderbotSmallTokenizer,
)
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class BlenderbotSmallTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/camembert/__init__.py b/tests/models/bloom/__init__.py
similarity index 100%
rename from tests/camembert/__init__.py
rename to tests/models/bloom/__init__.py
diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py
new file mode 100644
index 00000000000000..f71618eae8454f
--- /dev/null
+++ b/tests/models/bloom/test_modeling_bloom.py
@@ -0,0 +1,757 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+import unittest
+
+from transformers import BloomConfig, is_torch_available
+from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
+ BloomForCausalLM,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ BloomModel,
+ BloomTokenizerFast,
+ )
+
+
+@require_torch
+class BloomModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=14,
+ seq_length=7,
+ is_training=True,
+ use_token_type_ids=False,
+ use_input_mask=True,
+ use_labels=True,
+ use_mc_token_ids=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_token_type_ids = use_token_type_ids
+ self.use_input_mask = use_input_mask
+ self.use_labels = use_labels
+ self.use_mc_token_ids = use_mc_token_ids
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = None
+ self.bos_token_id = vocab_size - 1
+ self.eos_token_id = vocab_size - 1
+ self.pad_token_id = vocab_size - 1
+
+ def get_large_model_config(self):
+ return BloomConfig.from_pretrained("bigscience/bloom")
+
+ def prepare_config_and_inputs(self, gradient_checkpointing=False):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ sequence_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config(gradient_checkpointing=gradient_checkpointing)
+
+ return (config, input_ids, input_mask, sequence_labels)
+
+ def get_config(self, gradient_checkpointing=False, slow_but_exact=True):
+ return BloomConfig(
+ vocab_size=self.vocab_size,
+ seq_length=self.seq_length,
+ hidden_size=self.hidden_size,
+ n_layer=self.num_hidden_layers,
+ n_head=self.num_attention_heads,
+ resid_pdrop=self.hidden_dropout_prob,
+ attn_pdrop=self.attention_probs_dropout_prob,
+ n_positions=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ initializer_range=self.initializer_range,
+ use_cache=True,
+ bos_token_id=self.bos_token_id,
+ eos_token_id=self.eos_token_id,
+ pad_token_id=self.pad_token_id,
+ num_labels=self.num_labels,
+ gradient_checkpointing=gradient_checkpointing,
+ slow_but_exact=slow_but_exact,
+ dtype="float32",
+ )
+
+ def create_and_check_bloom_model(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids)
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(len(result.past_key_values), config.n_layer)
+
+ def create_and_check_bloom_model_past(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=torch.ones_like(input_ids), use_cache=True)
+ outputs_use_cache_conf = model(input_ids, attention_mask=torch.ones_like(input_ids))
+ outputs_no_past = model(input_ids, use_cache=False, attention_mask=torch.ones_like(input_ids))
+
+ self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+ self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+ past = outputs["past_key_values"]
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # append to next input_ids and token_type_ids
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+ output_from_no_past = model(next_input_ids)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_bloom_model_attention_mask_past(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # create attention mask
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+ half_seq_length = self.seq_length // 2
+ attn_mask[:, half_seq_length:] = 0
+
+ # first forward pass
+ output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # change a random masked slice from input_ids
+ random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+ random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+ input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
+
+ # append to next input_ids and attn_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ attn_mask = torch.cat(
+ [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
+ dim=1,
+ )
+
+ # get two different outputs
+ output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_bloom_model_past_large_inputs(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
+
+ output, past = outputs.to_tuple()
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append to next input_ids and token_type_ids
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past)[
+ "last_hidden_state"
+ ]
+ self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args):
+ model = BloomForCausalLM(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, labels=input_ids)
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_sequence_classification_model(self, config, input_ids, input_mask, *args):
+ config.num_labels = self.num_labels
+ model = BloomForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, attention_mask=input_mask)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def create_and_check_token_classification_model(self, config, input_ids, input_mask, *args):
+ model = BloomForTokenClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, attention_mask=input_mask)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
+ def create_and_check_forward_and_backwards(
+ self, config, input_ids, input_mask, *args, gradient_checkpointing=False
+ ):
+ model = BloomForCausalLM(config)
+ model.to(torch_device)
+ if gradient_checkpointing:
+ model.gradient_checkpointing_enable()
+
+ result = model(input_ids, labels=input_ids)
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+ result.loss.backward()
+
+ def create_and_check_bloom_weight_initialization(self, config, *args):
+ model = BloomModel(config)
+ model_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layer)
+ for key in model.state_dict().keys():
+ if "c_proj" in key and "weight" in key:
+ self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
+ self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+
+ config, input_ids, input_mask, sequence_labels = config_and_inputs
+
+ inputs_dict = {"input_ids": input_ids}
+
+ return config, inputs_dict
+
+
+@require_torch
+class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (
+ (
+ BloomModel,
+ BloomForCausalLM,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ )
+ if is_torch_available()
+ else ()
+ )
+
+ all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
+ fx_compatible = False
+ test_missing_keys = False
+ test_pruning = False
+ test_torchscript = True # torch.autograd functions seems to be not supported
+
+ def setUp(self):
+ self.model_tester = BloomModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=BloomConfig, n_embd=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_bloom_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model(*config_and_inputs)
+
+ def test_bloom_model_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_past(*config_and_inputs)
+
+ def test_bloom_model_att_mask_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_attention_mask_past(*config_and_inputs)
+
+ def test_bloom_model_past_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_past_large_inputs(*config_and_inputs)
+
+ def test_bloom_lm_head_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
+
+ def test_bloom_sequence_classification_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_sequence_classification_model(*config_and_inputs)
+
+ def test_bloom_token_classification_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_token_classification_model(*config_and_inputs)
+
+ def test_bloom_gradient_checkpointing(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
+ def test_bloom_weight_initialization(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = BloomModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ @slow
+ @require_torch_gpu
+ def test_simple_generation(self):
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m)
+
+ input_sentence = "I enjoy walking with my cute dog"
+ EXPECTED_OUTPUT = (
+ "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am"
+ " a very good listener. I am a very good person, and I am a very good person. I am a"
+ )
+
+ input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
+ greedy_output = model.generate(input_ids.cuda(), max_length=50)
+
+ self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
+
+ @slow
+ @require_torch_gpu
+ def test_batch_generation(self):
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
+
+ input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
+
+ input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
+ greedy_output = model.generate(
+ input_ids["input_ids"].cuda(), attention_mask=input_ids["attention_mask"], max_length=50, do_sample=False
+ )
+
+ self.assertEqual(
+ tokenizer.decode(greedy_output[0], skip_special_tokens=True),
+ tokenizer.decode(greedy_output[1], skip_special_tokens=True),
+ )
+
+ @slow
+ @require_torch_gpu
+ def test_batch_generation_padd(self):
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
+
+ input_sentence = ["I enjoy walking with my cute dog", "Hello my name is"]
+ input_sentence_without_pad = "Hello my name is"
+
+ input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
+ input_ids_without_pad = tokenizer.encode(input_sentence_without_pad, return_tensors="pt")
+
+ greedy_output = model.generate(
+ input_ids["input_ids"].cuda(), attention_mask=input_ids["attention_mask"], max_length=50, do_sample=False
+ )
+ greedy_output_without_pad = model.generate(input_ids_without_pad.cuda(), max_length=50, do_sample=False)
+
+ # test token values
+ self.assertEqual(greedy_output[-1, 3:].tolist(), greedy_output_without_pad[0, :-3].tolist())
+
+ # test reconstructions
+ self.assertEqual(
+ tokenizer.decode(greedy_output[-1, 3:], skip_special_tokens=True),
+ tokenizer.decode(greedy_output_without_pad[0, :-3], skip_special_tokens=True),
+ )
+
+
+@require_torch
+class BloomEmbeddingTest(unittest.TestCase):
+ """
+ The goal here is to compare the embeddings generated by the model trained
+ using Megatron-LM with the one from the transformers library, with a small GPT2-like model
+ to ensure that the conversion from Megatron-LM to transformers has been done successfully.
+ The script compares the logits of the embedding layer and the transformer layers.
+
+ WARNING: It is expected that these logits will not have exactly the same statistics when running
+ the code on CPU or GPU. For more info, please visit:
+ - https://github.com/pytorch/pytorch/issues/76052#issuecomment-1103193548
+ - https://discuss.pytorch.org/t/reproducibility-issue-between-intel-and-amd-cpus/144779/9
+
+
+ You need to install tokenizers following this readme:
+ - https://huggingface.co/bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
+
+ Tokenizer used during training:
+ - https://huggingface.co/bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
+
+ # TODO change the script (or just add skip) when building the env with tokenizers 0.12.0
+ """
+
+ def setUp(self):
+ super().setUp()
+ self.path_bigscience_model = "bigscience/bigscience-small-testing"
+
+ @require_torch
+ def test_embeddings(self):
+ model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, torch_dtype="auto") # load in fp32
+ model.eval()
+
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MEAN = {
+ 3478: 0.0002307891845703125,
+ 368: -0.000568389892578125,
+ 109586: -0.0003910064697265625,
+ 35433: -0.000194549560546875,
+ 2: 0.0004138946533203125,
+ 77: 0.000659942626953125,
+ 132619: -0.00031280517578125,
+ 2175: 0.000457763671875,
+ 23714: 0.000263214111328125,
+ 73173: -0.000286102294921875,
+ 144252: 0.00052642822265625,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_SUM = {"value": 0.08203125}
+
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MEAN = {
+ 132619: -0.00031256675720214844,
+ 3478: 0.00023090839385986328,
+ 368: -0.0005702972412109375,
+ 109586: -0.00039124488830566406,
+ 35433: -0.000194549560546875,
+ 2: 0.0004146099090576172,
+ 2175: 0.0004572868347167969,
+ 23714: 0.00026416778564453125,
+ 73173: -0.0002865791320800781,
+ 144252: 0.0005254745483398438,
+ 77: 0.0006618499755859375,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_SUM = {"value": 0.0821533203125}
+
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN = {
+ 132619: -0.00031267106533050537,
+ 3478: 0.00023087859153747559,
+ 368: -0.0005701072514057159,
+ 109586: -0.0003911703824996948,
+ 35433: -0.0001944899559020996,
+ 2: 0.0004146844148635864,
+ 2175: 0.00045740045607089996,
+ 23714: 0.0002641640603542328,
+ 73173: -0.0002864748239517212,
+ 144252: 0.0005256589502096176,
+ 77: 0.0006617321632802486,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_SUM = {"value": 0.08217757940292358}
+
+ TEST_EMBEDDINGS = {
+ "torch.bfloat16": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_BF_16_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_BF_16_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_BF_16_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_BF_16_SUM,
+ },
+ "torch.float32": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_32_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_32_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_32_SUM,
+ },
+ "torch.float": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_32_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_32_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_32_SUM,
+ },
+ "torch.float16": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_16_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_16_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_16_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_16_SUM,
+ },
+ }
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ EMBEDDINGS_DS_AFTER_LN_MEAN = {
+ 3478: -6.580352783203125e-05,
+ 368: 0.0001316070556640625,
+ 109586: -0.00030517578125,
+ 35433: 4.00543212890625e-05,
+ 2: -7.2479248046875e-05,
+ 77: -8.96453857421875e-05,
+ 132619: 0.0001583099365234375,
+ 2175: 2.1219253540039062e-05,
+ 23714: -0.000247955322265625,
+ 73173: -0.00021839141845703125,
+ 144252: -0.0001430511474609375,
+ }
+ EMBEDDINGS_DS_AFTER_LN_MIN = {
+ 3478: -1.6953125,
+ 368: -1.6875,
+ 109586: -1.6875,
+ 35433: -2.125,
+ 2: -1.390625,
+ 77: -1.5390625,
+ 132619: -1.875,
+ 2175: -1.4609375,
+ 23714: -2.296875,
+ 73173: -1.3515625,
+ 144252: -1.78125,
+ }
+ EMBEDDINGS_DS_AFTER_LN_MAX = {
+ 3478: 2.265625,
+ 368: 2.28125,
+ 109586: 1.953125,
+ 35433: 1.90625,
+ 2: 2.703125,
+ 77: 2.828125,
+ 132619: 1.65625,
+ 2175: 2.015625,
+ 23714: 2.234375,
+ 73173: 2.171875,
+ 144252: 1.828125,
+ }
+
+ EMBEDDINGS_DS_AFTER_LN = {
+ "mean": EMBEDDINGS_DS_AFTER_LN_MEAN,
+ "min": EMBEDDINGS_DS_AFTER_LN_MIN,
+ "max": EMBEDDINGS_DS_AFTER_LN_MAX,
+ }
+
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS])
+ with torch.no_grad():
+ embeddings = model.transformer.word_embeddings(tensor_ids)
+ embeddings_ln = model.transformer.word_embeddings_layernorm(embeddings) #
+ # first check the embeddings before LN
+ output_dict = {"min": {}, "max": {}, "mean": {}, "sum": {"value": embeddings.sum().item()}}
+ for i, idx in enumerate(EXAMPLE_IDS):
+ output_dict["min"][idx] = embeddings.min(dim=-1).values[0][i].item()
+ output_dict["max"][idx] = embeddings.max(dim=-1).values[0][i].item()
+ output_dict["mean"][idx] = embeddings.mean(dim=-1)[0][i].item()
+
+ for key in TEST_EMBEDDINGS[str(model.dtype)].keys():
+ self.assertDictEqual(TEST_EMBEDDINGS[str(model.dtype)][key], output_dict[key])
+
+ output_dict_norm = {"min": {}, "max": {}, "mean": {}}
+ for i, idx in enumerate(EXAMPLE_IDS):
+ output_dict_norm["min"][idx] = embeddings_ln.min(dim=-1).values[0][i].item()
+ output_dict_norm["max"][idx] = embeddings_ln.max(dim=-1).values[0][i].item()
+ output_dict_norm["mean"][idx] = embeddings_ln.mean(dim=-1)[0][i].item()
+
+ # This test does not pass when places = 2
+ for i, key in enumerate(output_dict_norm.keys()):
+ for j, idx in enumerate(output_dict[key].keys()):
+ self.assertAlmostEqual(EMBEDDINGS_DS_AFTER_LN[key][idx], output_dict_norm[key][idx], places=1)
+
+ @require_torch
+ def test_hidden_states_transformers(self):
+ cuda_available = torch.cuda.is_available()
+ model = BloomModel.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to(
+ torch_device
+ )
+ model.eval()
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ MEAN_VALUE_LAST_LM = -4.3392181396484375e-05
+ MIN_MAX_DICT = {"min": -2.0625, "max": 2.75}
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS])
+
+ with torch.no_grad():
+ logits = model(tensor_ids.to(torch_device))
+ output_dict = {
+ "min": logits.last_hidden_state.min(dim=-1).values[0][0].item(),
+ "max": logits.last_hidden_state.max(dim=-1).values[0][0].item(),
+ }
+
+ if cuda_available:
+ self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=4)
+ else:
+ self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=3)
+
+ self.assertDictEqual(MIN_MAX_DICT, output_dict)
+
+ @require_torch
+ def test_logits(self):
+ cuda_available = torch.cuda.is_available()
+ model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to(
+ torch_device
+ ) # load in bf16
+ model.eval()
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ MEAN_LOGITS_GPU_1 = -1.823902130126953e-05
+ MEAN_LOGITS_GPU_2 = 1.9431114196777344e-05
+
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS]).to(torch_device)
+ with torch.no_grad():
+ output = model(tensor_ids).logits
+
+ output_gpu_1, output_gpu_2 = output.split(125440, dim=-1)
+ if cuda_available:
+ self.assertEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1)
+ self.assertEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2)
+ else:
+ self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) # 1e-06 precision!!
+ self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
diff --git a/tests/models/bloom/test_tokenization_bloom.py b/tests/models/bloom/test_tokenization_bloom.py
new file mode 100644
index 00000000000000..c213437a37dd00
--- /dev/null
+++ b/tests/models/bloom/test_tokenization_bloom.py
@@ -0,0 +1,129 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from datasets import load_dataset
+
+from transformers import BloomTokenizerFast
+from transformers.testing_utils import require_tokenizers
+
+from ...test_tokenization_common import TokenizerTesterMixin
+
+
+@require_tokenizers
+class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ slow_tokenizer_class = None
+ rust_tokenizer_class = BloomTokenizerFast
+ tokenizer_class = BloomTokenizerFast
+ test_rust_tokenizer = True
+ test_slow_tokenizer = False
+ from_pretrained_vocab_key = "tokenizer_file"
+ special_tokens_map = {"bos_token": "", "eos_token": "", "unk_token": "", "pad_token": ""}
+
+ def setUp(self):
+ super().setUp()
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/tokenizer")
+ tokenizer.save_pretrained(self.tmpdirname)
+
+ def get_rust_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return BloomTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def test_encodings_from_sample_data(self):
+ """
+ Assert that the created tokens are the same than the hard-coded ones
+ """
+ tokenizer = self.get_rust_tokenizer()
+
+ INPUT_SENTENCES = ["The quick brown fox", "jumps over the lazy dog"]
+ TARGET_TOKENS = [[2175, 23714, 73173, 144252, 2], [77, 132619, 3478, 368, 109586, 35433, 2]]
+
+ computed_tokens = tokenizer.batch_encode_plus(INPUT_SENTENCES)["input_ids"]
+ self.assertListEqual(TARGET_TOKENS, computed_tokens)
+
+ decoded_tokens = tokenizer.batch_decode(computed_tokens)
+ self.assertListEqual(decoded_tokens, INPUT_SENTENCES)
+
+ def test_padding(self, max_length=6):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ # tokenizer_r.pad_token = None # Hotfixing padding = None
+ # Simple input
+ s = "This is a simple input"
+ s2 = ["This is a simple input 1", "This is a simple input 2"]
+ p = ("This is a simple input", "This is a pair")
+ p2 = [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ]
+
+ # Simple input tests
+ try:
+ tokenizer_r.encode(s, max_length=max_length)
+ tokenizer_r.encode_plus(s, max_length=max_length)
+
+ tokenizer_r.batch_encode_plus(s2, max_length=max_length)
+ tokenizer_r.encode(p, max_length=max_length)
+ tokenizer_r.batch_encode_plus(p2, max_length=max_length)
+ except ValueError:
+ self.fail("Bloom Tokenizer should be able to deal with padding")
+
+ tokenizer_r.pad_token = None # Hotfixing padding = None
+ self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ s2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ p2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ def test_encodings_from_xnli_dataset(self):
+ """
+ Tests the tokenizer downloaded from here:
+ - https://huggingface.co/bigscience/tokenizer/
+ """
+ tokenizer = self.get_rust_tokenizer()
+ ds = load_dataset("xnli", "all_languages", split="test", streaming=True)
+
+ sample_data = next(iter(ds))["premise"] # pick up one data
+ input_text = list(sample_data.values())
+
+ output_tokens = list(map(tokenizer.encode, input_text))
+ predicted_text = list(map(lambda x: tokenizer.decode(x, clean_up_tokenization_spaces=False), output_tokens))
+ self.assertListEqual(predicted_text, input_text)
diff --git a/tests/canine/__init__.py b/tests/models/bort/__init__.py
similarity index 100%
rename from tests/canine/__init__.py
rename to tests/models/bort/__init__.py
diff --git a/tests/bort/test_modeling_bort.py b/tests/models/bort/test_modeling_bort.py
similarity index 100%
rename from tests/bort/test_modeling_bort.py
rename to tests/models/bort/test_modeling_bort.py
diff --git a/tests/bort/test_modeling_tf_bort.py b/tests/models/bort/test_modeling_tf_bort.py
similarity index 100%
rename from tests/bort/test_modeling_tf_bort.py
rename to tests/models/bort/test_modeling_tf_bort.py
diff --git a/tests/clip/__init__.py b/tests/models/byt5/__init__.py
similarity index 100%
rename from tests/clip/__init__.py
rename to tests/models/byt5/__init__.py
diff --git a/tests/byt5/test_tokenization_byt5.py b/tests/models/byt5/test_tokenization_byt5.py
similarity index 99%
rename from tests/byt5/test_tokenization_byt5.py
rename to tests/models/byt5/test_tokenization_byt5.py
index 7e4f97d3741b6f..70cfa40ef9196c 100644
--- a/tests/byt5/test_tokenization_byt5.py
+++ b/tests/models/byt5/test_tokenization_byt5.py
@@ -24,7 +24,7 @@
from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
from transformers.utils import cached_property, is_tf_available, is_torch_available
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
if is_torch_available():
diff --git a/tests/convbert/__init__.py b/tests/models/camembert/__init__.py
similarity index 100%
rename from tests/convbert/__init__.py
rename to tests/models/camembert/__init__.py
diff --git a/tests/camembert/test_modeling_camembert.py b/tests/models/camembert/test_modeling_camembert.py
similarity index 100%
rename from tests/camembert/test_modeling_camembert.py
rename to tests/models/camembert/test_modeling_camembert.py
diff --git a/tests/camembert/test_modeling_tf_camembert.py b/tests/models/camembert/test_modeling_tf_camembert.py
similarity index 100%
rename from tests/camembert/test_modeling_tf_camembert.py
rename to tests/models/camembert/test_modeling_tf_camembert.py
diff --git a/tests/camembert/test_tokenization_camembert.py b/tests/models/camembert/test_tokenization_camembert.py
similarity index 93%
rename from tests/camembert/test_tokenization_camembert.py
rename to tests/models/camembert/test_tokenization_camembert.py
index a2274e3b05bb6d..aff186d73cb065 100644
--- a/tests/camembert/test_tokenization_camembert.py
+++ b/tests/models/camembert/test_tokenization_camembert.py
@@ -13,19 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from transformers import CamembertTokenizer, CamembertTokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import is_torch_available
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
-SAMPLE_BPE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece_bpe.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
+SAMPLE_BPE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe.model")
FRAMEWORK = "pt" if is_torch_available() else "tf"
diff --git a/tests/convnext/__init__.py b/tests/models/canine/__init__.py
similarity index 100%
rename from tests/convnext/__init__.py
rename to tests/models/canine/__init__.py
diff --git a/tests/canine/test_modeling_canine.py b/tests/models/canine/test_modeling_canine.py
similarity index 98%
rename from tests/canine/test_modeling_canine.py
rename to tests/models/canine/test_modeling_canine.py
index 5e3c37b37edede..a4d13f0efab6c4 100644
--- a/tests/canine/test_modeling_canine.py
+++ b/tests/models/canine/test_modeling_canine.py
@@ -21,8 +21,8 @@
from transformers import CanineConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, _config_zero_init, global_rng, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, global_rng, ids_tensor, random_attention_mask
if is_torch_available():
@@ -378,7 +378,12 @@ def recursive_check(tuple_object, dict_object):
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
- msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
)
recursive_check(tuple_output, dict_output)
diff --git a/tests/canine/test_tokenization_canine.py b/tests/models/canine/test_tokenization_canine.py
similarity index 99%
rename from tests/canine/test_tokenization_canine.py
rename to tests/models/canine/test_tokenization_canine.py
index 0a949e6d78fdda..0e016d523b5cb9 100644
--- a/tests/canine/test_tokenization_canine.py
+++ b/tests/models/canine/test_tokenization_canine.py
@@ -24,7 +24,7 @@
from transformers.tokenization_utils import AddedToken
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class CanineTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/cpm/__init__.py b/tests/models/clip/__init__.py
similarity index 100%
rename from tests/cpm/__init__.py
rename to tests/models/clip/__init__.py
diff --git a/tests/clip/test_feature_extraction_clip.py b/tests/models/clip/test_feature_extraction_clip.py
similarity index 75%
rename from tests/clip/test_feature_extraction_clip.py
rename to tests/models/clip/test_feature_extraction_clip.py
index 85cb54510c8c51..8f36a65ae2d596 100644
--- a/tests/clip/test_feature_extraction_clip.py
+++ b/tests/models/clip/test_feature_extraction_clip.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin
if is_torch_available():
@@ -49,6 +49,7 @@ def __init__(
do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
+ do_convert_rgb=True,
):
self.parent = parent
self.batch_size = batch_size
@@ -63,6 +64,7 @@ def __init__(
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
+ self.do_convert_rgb = do_convert_rgb
def prepare_feat_extract_dict(self):
return {
@@ -73,6 +75,7 @@ def prepare_feat_extract_dict(self):
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
+ "do_convert_rgb": self.do_convert_rgb,
}
def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False):
@@ -128,6 +131,7 @@ def test_feat_extract_properties(self):
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_batch_feature(self):
pass
@@ -227,3 +231,64 @@ def test_call_pytorch(self):
self.feature_extract_tester.crop_size,
),
)
+
+
+@require_torch
+@require_vision
+class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = CLIPFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = CLIPFeatureExtractionTester(self, num_channels=4)
+ self.expected_encoded_image_num_channels = 3
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil_four_channels(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.expected_encoded_image_num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.expected_encoded_image_num_channels,
+ self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size,
+ ),
+ )
diff --git a/tests/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py
similarity index 83%
rename from tests/clip/test_modeling_clip.py
rename to tests/models/clip/test_modeling_clip.py
index 88649b31455ceb..ab05f9adf1e870 100644
--- a/tests/clip/test_modeling_clip.py
+++ b/tests/models/clip/test_modeling_clip.py
@@ -35,8 +35,8 @@
)
from transformers.utils import is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -100,6 +100,10 @@ def __init__(
self.initializer_range = initializer_range
self.scope = scope
+ # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches + 1
+
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
@@ -148,7 +152,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
"""
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
-
+ fx_compatible = True
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
@@ -160,8 +164,8 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()
+ @unittest.skip(reason="CLIP does not use inputs_embeds")
def test_inputs_embeds(self):
- # CLIP does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -189,114 +193,17 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- # in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
- image_size = (self.model_tester.image_size, self.model_tester.image_size)
- patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_len = num_patches + 1
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- # CLIP has a different seq_length
- image_size = (self.model_tester.image_size, self.model_tester.image_size)
- patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- seq_length = num_patches + 1
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
def test_training(self):
pass
def test_training_gradient_checkpointing(self):
pass
- # skip this test as CLIPVisionModel has no base class and is
- # not available in MODEL_MAPPING
+ @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
- # skip this test as CLIPVisionModel has no base class and is
- # not available in MODEL_MAPPING
+ @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@@ -396,6 +303,7 @@ def prepare_config_and_inputs_for_common(self):
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
test_head_masking = False
@@ -416,17 +324,15 @@ def test_training(self):
def test_training_gradient_checkpointing(self):
pass
+ @unittest.skip(reason="CLIP does not use inputs_embeds")
def test_inputs_embeds(self):
- # CLIP does not use inputs_embeds
pass
- # skip this test as CLIPTextModel has no base class and is
- # not available in MODEL_MAPPING
+ @unittest.skip(reason="CLIPTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
- # skip this test as CLIPTextModel has no base class and is
- # not available in MODEL_MAPPING
+ @unittest.skip(reason="CLIPTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@@ -483,6 +389,7 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPModel,) if is_torch_available() else ()
+ fx_compatible = True
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
@@ -495,19 +402,19 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
- # hidden_states are tested in individual model tests
+ @unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
- # input_embeds are tested in individual model tests
+ @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
- # tested in individual model tests
+ @unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
- # CLIPModel does not have input/output embeddings
+ @unittest.skip(reason="CLIPModel does not have input/output embeddings")
def test_model_common_attributes(self):
pass
diff --git a/tests/clip/test_modeling_flax_clip.py b/tests/models/clip/test_modeling_flax_clip.py
similarity index 99%
rename from tests/clip/test_modeling_flax_clip.py
rename to tests/models/clip/test_modeling_flax_clip.py
index adad20befa7234..b8a1030ad1b0c8 100644
--- a/tests/clip/test_modeling_flax_clip.py
+++ b/tests/models/clip/test_modeling_flax_clip.py
@@ -8,7 +8,7 @@
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
-from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
diff --git a/tests/clip/test_modeling_tf_clip.py b/tests/models/clip/test_modeling_tf_clip.py
similarity index 80%
rename from tests/clip/test_modeling_tf_clip.py
rename to tests/models/clip/test_modeling_tf_clip.py
index 7ba93524062217..797d5b73b3493a 100644
--- a/tests/clip/test_modeling_tf_clip.py
+++ b/tests/models/clip/test_modeling_tf_clip.py
@@ -26,8 +26,8 @@
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import is_tf_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_tf_available():
@@ -256,6 +256,62 @@ def test_model_from_pretrained(self):
model = TFCLIPVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ @slow
+ def test_saved_model_creation_extended(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ if hasattr(config, "use_cache"):
+ config.use_cache = True
+
+ # in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_len = num_patches + 1
+
+ for model_class in self.all_model_classes:
+ class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config)
+ num_out = len(model(class_inputs_dict))
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname, saved_model=True)
+ saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
+ model = tf.keras.models.load_model(saved_model_dir)
+ outputs = model(class_inputs_dict)
+ output_hidden_states = outputs["hidden_states"]
+ output_attentions = outputs["attentions"]
+
+ # Check num outputs
+ self.assertEqual(len(outputs), num_out)
+
+ # Check num layers
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+ )
+
+ self.assertEqual(len(output_hidden_states), expected_num_layers)
+ self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
+
+ # Check attention outputs
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_len = num_patches + 1
+
+ self.assertListEqual(
+ list(output_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+
+ # Check hidden states
+ self.assertListEqual(
+ list(output_hidden_states[0].shape[-2:]),
+ [seq_len, self.model_tester.hidden_size],
+ )
+
class TFCLIPTextModelTester:
def __init__(
@@ -367,6 +423,54 @@ def test_model_from_pretrained(self):
model = TFCLIPTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ @slow
+ def test_saved_model_creation_extended(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ if hasattr(config, "use_cache"):
+ config.use_cache = True
+
+ for model_class in self.all_model_classes:
+ class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config)
+ num_out = len(model(class_inputs_dict))
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname, saved_model=True)
+ saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
+ model = tf.keras.models.load_model(saved_model_dir)
+ outputs = model(class_inputs_dict)
+ output_hidden_states = outputs["hidden_states"]
+ output_attentions = outputs["attentions"]
+
+ # Check number of outputs
+ self.assertEqual(len(outputs), num_out)
+
+ # Check number of layers
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+ )
+
+ # Check hidden states
+ self.assertEqual(len(output_hidden_states), expected_num_layers)
+ self.assertListEqual(
+ list(output_hidden_states[0].shape[-2:]),
+ [self.model_tester.seq_length, self.model_tester.hidden_size],
+ )
+
+ # Check attention outputs
+ self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
+
+ seq_length = self.model_tester.seq_length
+ key_length = getattr(self.model_tester, "key_length", seq_length)
+
+ self.assertListEqual(
+ list(output_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_length, key_length],
+ )
+
class TFCLIPModelTester:
def __init__(self, parent, is_training=True):
@@ -502,6 +606,11 @@ def test_model_from_pretrained(self):
model = TFCLIPModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ @unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
+ @slow
+ def test_saved_model_creation_extended(self):
+ pass
+
# We will verify our results on an image of cute cats
def prepare_img():
diff --git a/tests/clip/test_processor_clip.py b/tests/models/clip/test_processor_clip.py
similarity index 100%
rename from tests/clip/test_processor_clip.py
rename to tests/models/clip/test_processor_clip.py
diff --git a/tests/clip/test_tokenization_clip.py b/tests/models/clip/test_tokenization_clip.py
similarity index 99%
rename from tests/clip/test_tokenization_clip.py
rename to tests/models/clip/test_tokenization_clip.py
index 2ad48ca710a190..e9ba304b475dd6 100644
--- a/tests/clip/test_tokenization_clip.py
+++ b/tests/models/clip/test_tokenization_clip.py
@@ -22,7 +22,7 @@
from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES
from transformers.testing_utils import require_ftfy, require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/ctrl/__init__.py b/tests/models/convbert/__init__.py
similarity index 100%
rename from tests/ctrl/__init__.py
rename to tests/models/convbert/__init__.py
diff --git a/tests/convbert/test_modeling_convbert.py b/tests/models/convbert/test_modeling_convbert.py
similarity index 99%
rename from tests/convbert/test_modeling_convbert.py
rename to tests/models/convbert/test_modeling_convbert.py
index a6b41b02d2da4a..d3eb0aec4cfc0d 100644
--- a/tests/convbert/test_modeling_convbert.py
+++ b/tests/models/convbert/test_modeling_convbert.py
@@ -21,8 +21,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/convbert/test_modeling_tf_convbert.py b/tests/models/convbert/test_modeling_tf_convbert.py
similarity index 99%
rename from tests/convbert/test_modeling_tf_convbert.py
rename to tests/models/convbert/test_modeling_tf_convbert.py
index 2ae29c3e4a5a4b..ae675b878ed145 100644
--- a/tests/convbert/test_modeling_tf_convbert.py
+++ b/tests/models/convbert/test_modeling_tf_convbert.py
@@ -19,8 +19,8 @@
from transformers import ConvBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/data2vec/__init__.py b/tests/models/convnext/__init__.py
similarity index 100%
rename from tests/data2vec/__init__.py
rename to tests/models/convnext/__init__.py
diff --git a/tests/convnext/test_feature_extraction_convnext.py b/tests/models/convnext/test_feature_extraction_convnext.py
similarity index 98%
rename from tests/convnext/test_feature_extraction_convnext.py
rename to tests/models/convnext/test_feature_extraction_convnext.py
index 8127d468b37713..f02341972ba03d 100644
--- a/tests/convnext/test_feature_extraction_convnext.py
+++ b/tests/models/convnext/test_feature_extraction_convnext.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/convnext/test_modeling_convnext.py b/tests/models/convnext/test_modeling_convnext.py
similarity index 97%
rename from tests/convnext/test_modeling_convnext.py
rename to tests/models/convnext/test_modeling_convnext.py
index 68a42f38af2d4d..46ef3ce71709cc 100644
--- a/tests/convnext/test_modeling_convnext.py
+++ b/tests/models/convnext/test_modeling_convnext.py
@@ -22,8 +22,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -158,6 +158,10 @@ def test_config(self):
def create_and_test_config_common_properties(self):
return
+ @unittest.skip(reason="ConvNext does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip(reason="ConvNext does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/convnext/test_modeling_tf_convnext.py b/tests/models/convnext/test_modeling_tf_convnext.py
similarity index 94%
rename from tests/convnext/test_modeling_tf_convnext.py
rename to tests/models/convnext/test_modeling_tf_convnext.py
index 579c27dd27a624..bc84cd0a40007e 100644
--- a/tests/convnext/test_modeling_tf_convnext.py
+++ b/tests/models/convnext/test_modeling_tf_convnext.py
@@ -22,8 +22,8 @@
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
@@ -174,6 +174,13 @@ def test_model(self):
def test_attention_outputs(self):
pass
+ @unittest.skipIf(
+ not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
+ reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
+ )
+ def test_dataset_conversion(self):
+ super().test_dataset_conversion()
+
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
@@ -219,7 +226,10 @@ def recursive_check(tuple_object, dict_object):
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
- msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
+ ),
)
recursive_check(tuple_output, dict_output)
diff --git a/tests/deberta/__init__.py b/tests/models/cpm/__init__.py
similarity index 100%
rename from tests/deberta/__init__.py
rename to tests/models/cpm/__init__.py
diff --git a/tests/cpm/test_tokenization_cpm.py b/tests/models/cpm/test_tokenization_cpm.py
similarity index 100%
rename from tests/cpm/test_tokenization_cpm.py
rename to tests/models/cpm/test_tokenization_cpm.py
diff --git a/tests/deberta_v2/__init__.py b/tests/models/ctrl/__init__.py
similarity index 100%
rename from tests/deberta_v2/__init__.py
rename to tests/models/ctrl/__init__.py
diff --git a/tests/ctrl/test_modeling_ctrl.py b/tests/models/ctrl/test_modeling_ctrl.py
similarity index 93%
rename from tests/ctrl/test_modeling_ctrl.py
rename to tests/models/ctrl/test_modeling_ctrl.py
index af754399b81a82..ad6652f882d5bd 100644
--- a/tests/ctrl/test_modeling_ctrl.py
+++ b/tests/models/ctrl/test_modeling_ctrl.py
@@ -13,14 +13,15 @@
# limitations under the License.
+import gc
import unittest
from transformers import CTRLConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
@@ -181,6 +182,12 @@ def setUp(self):
self.model_tester = CTRLModelTester(self)
self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)
+ def tearDown(self):
+ super().tearDown()
+ # clean-up as much as possible GPU memory occupied by PyTorch
+ gc.collect()
+ torch.cuda.empty_cache()
+
def test_config(self):
self.config_tester.run_common_tests()
@@ -201,6 +208,12 @@ def test_model_from_pretrained(self):
@require_torch
class CTRLModelLanguageGenerationTest(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ # clean-up as much as possible GPU memory occupied by PyTorch
+ gc.collect()
+ torch.cuda.empty_cache()
+
@slow
def test_lm_generate_ctrl(self):
model = CTRLLMHeadModel.from_pretrained("ctrl")
diff --git a/tests/ctrl/test_modeling_tf_ctrl.py b/tests/models/ctrl/test_modeling_tf_ctrl.py
similarity index 98%
rename from tests/ctrl/test_modeling_tf_ctrl.py
rename to tests/models/ctrl/test_modeling_tf_ctrl.py
index d17a97a3ad83cd..d3e82e57c9f27f 100644
--- a/tests/ctrl/test_modeling_tf_ctrl.py
+++ b/tests/models/ctrl/test_modeling_tf_ctrl.py
@@ -19,8 +19,8 @@
from transformers import CTRLConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/ctrl/test_tokenization_ctrl.py b/tests/models/ctrl/test_tokenization_ctrl.py
similarity index 97%
rename from tests/ctrl/test_tokenization_ctrl.py
rename to tests/models/ctrl/test_tokenization_ctrl.py
index 54eb8d218683a9..0bd4d8c8065cba 100644
--- a/tests/ctrl/test_tokenization_ctrl.py
+++ b/tests/models/ctrl/test_tokenization_ctrl.py
@@ -19,7 +19,7 @@
from transformers.models.ctrl.tokenization_ctrl import VOCAB_FILES_NAMES, CTRLTokenizer
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/decision_transformer/__init__.py b/tests/models/cvt/__init__.py
similarity index 100%
rename from tests/decision_transformer/__init__.py
rename to tests/models/cvt/__init__.py
diff --git a/tests/models/cvt/test_modeling_cvt.py b/tests/models/cvt/test_modeling_cvt.py
new file mode 100644
index 00000000000000..b88f22d982be78
--- /dev/null
+++ b/tests/models/cvt/test_modeling_cvt.py
@@ -0,0 +1,282 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch CvT model. """
+
+
+import inspect
+import unittest
+from math import floor
+
+from transformers import CvtConfig
+from transformers.file_utils import cached_property, is_torch_available, is_vision_available
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import CvtForImageClassification, CvtModel
+ from transformers.models.cvt.modeling_cvt import CVT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoFeatureExtractor
+
+
+class CvtConfigTester(ConfigTester):
+ def create_and_test_config_common_properties(self):
+ config = self.config_class(**self.inputs_dict)
+ self.parent.assertTrue(hasattr(config, "embed_dim"))
+ self.parent.assertTrue(hasattr(config, "num_heads"))
+
+
+class CvtModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=64,
+ num_channels=3,
+ embed_dim=[16, 48, 96],
+ num_heads=[1, 3, 6],
+ depth=[1, 2, 10],
+ patch_sizes=[7, 3, 3],
+ patch_stride=[4, 2, 2],
+ patch_padding=[2, 1, 1],
+ stride_kv=[2, 2, 2],
+ cls_token=[False, False, True],
+ attention_drop_rate=[0.0, 0.0, 0.0],
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ is_training=True,
+ use_labels=True,
+ num_labels=2, # Check
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_sizes = patch_sizes
+ self.patch_stride = patch_stride
+ self.patch_padding = patch_padding
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.num_labels = num_labels
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.stride_kv = stride_kv
+ self.depth = depth
+ self.cls_token = cls_token
+ self.attention_drop_rate = attention_drop_rate
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.num_labels)
+
+ config = self.get_config()
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return CvtConfig(
+ image_size=self.image_size,
+ num_labels=self.num_labels,
+ num_channels=self.num_channels,
+ embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ patch_sizes=self.patch_sizes,
+ patch_padding=self.patch_padding,
+ patch_stride=self.patch_stride,
+ stride_kv=self.stride_kv,
+ depth=self.depth,
+ cls_token=self.cls_token,
+ attention_drop_rate=self.attention_drop_rate,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = CvtModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ image_size = (self.image_size, self.image_size)
+ height, width = image_size[0], image_size[1]
+ for i in range(len(self.depth)):
+ height = floor(((height + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1)
+ width = floor(((width + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.embed_dim[-1], height, width))
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = CvtForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class CvtModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as Cvt does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else ()
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = CvtModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=CvtConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ @unittest.skip(reason="Cvt does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ @unittest.skip(reason="Cvt does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Cvt does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = len(self.model_tester.depth)
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # verify the first hidden states (first block)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-3:]),
+ [
+ self.model_tester.embed_dim[0],
+ self.model_tester.image_size // 4,
+ self.model_tester.image_size // 4,
+ ],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in CVT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = CvtModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class CvtModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return AutoFeatureExtractor.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = CvtForImageClassification.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device)
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([0.9285, 0.9015, -0.3150]).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
diff --git a/tests/deit/__init__.py b/tests/models/data2vec/__init__.py
similarity index 100%
rename from tests/deit/__init__.py
rename to tests/models/data2vec/__init__.py
diff --git a/tests/data2vec/test_modeling_data2vec_audio.py b/tests/models/data2vec/test_modeling_data2vec_audio.py
similarity index 98%
rename from tests/data2vec/test_modeling_data2vec_audio.py
rename to tests/models/data2vec/test_modeling_data2vec_audio.py
index ecadcb59039b71..e3fb96097d843f 100644
--- a/tests/data2vec/test_modeling_data2vec_audio.py
+++ b/tests/models/data2vec/test_modeling_data2vec_audio.py
@@ -24,8 +24,8 @@
from transformers import Data2VecAudioConfig, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_soundfile, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, _config_zero_init
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
@@ -116,7 +116,7 @@ def __init__(
self.adapter_output_seq_length = (self.output_seq_length - 1) // adapter_stride + 1
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
@@ -535,7 +535,7 @@ def _mock_init_weights(self, module):
def test_mask_feature_prob_ctc(self):
model = Data2VecAudioForCTC.from_pretrained(
- "facebook/data2vec-audio-base-960h", mask_feature_prob=0.2, mask_feature_length=2
+ "hf-internal-testing/tiny-random-data2vec-seq-class", mask_feature_prob=0.2, mask_feature_length=2
)
model.to(torch_device).train()
processor = Wav2Vec2Processor.from_pretrained(
@@ -554,7 +554,7 @@ def test_mask_feature_prob_ctc(self):
attention_mask=batch["attention_mask"].to(torch_device),
).logits
- self.assertEqual(logits.shape, (4, 299, 32))
+ self.assertEqual(logits.shape, (4, 1498, 32))
def test_mask_time_prob_ctc(self):
model = Data2VecAudioForCTC.from_pretrained(
@@ -736,7 +736,8 @@ def test_inference_ctc_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with thousands of spectators were trivialities not worth thinking about",
"his instant of panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/data2vec/test_modeling_data2vec_text.py b/tests/models/data2vec/test_modeling_data2vec_text.py
similarity index 99%
rename from tests/data2vec/test_modeling_data2vec_text.py
rename to tests/models/data2vec/test_modeling_data2vec_text.py
index 8b27cefb74bdb4..f37d64044a02ce 100644
--- a/tests/data2vec/test_modeling_data2vec_text.py
+++ b/tests/models/data2vec/test_modeling_data2vec_text.py
@@ -20,9 +20,9 @@
from transformers import Data2VecTextConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin
if is_torch_available():
diff --git a/tests/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py
similarity index 97%
rename from tests/data2vec/test_modeling_data2vec_vision.py
rename to tests/models/data2vec/test_modeling_data2vec_vision.py
index 6005e9b379593b..8966b909970a28 100644
--- a/tests/data2vec/test_modeling_data2vec_vision.py
+++ b/tests/models/data2vec/test_modeling_data2vec_vision.py
@@ -23,8 +23,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
@@ -389,6 +389,10 @@ def check_hidden_states_output(inputs_dict, config, model_class):
check_hidden_states_output(inputs_dict, config, model_class)
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
+ # We override with a slightly higher tol value, as semseg models tend to diverge a bit more
+ super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
+
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
diff --git a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py
new file mode 100644
index 00000000000000..eb085af0d82b56
--- /dev/null
+++ b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py
@@ -0,0 +1,495 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the TensorFlow Data2VecVision model. """
+
+import collections.abc
+import inspect
+import unittest
+
+import numpy as np
+
+from transformers import Data2VecVisionConfig
+from transformers.file_utils import cached_property, is_tf_available, is_vision_available
+from transformers.testing_utils import require_tf, require_vision, slow
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import (
+ TFData2VecVisionForImageClassification,
+ TFData2VecVisionForSemanticSegmentation,
+ TFData2VecVisionModel,
+ )
+ from transformers.models.data2vec.modeling_tf_data2vec_vision import (
+ TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
+ )
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import BeitFeatureExtractor
+
+
+class TFData2VecVisionModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=100,
+ batch_size=13,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=4,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ num_labels=3,
+ scope=None,
+ out_indices=[0, 1, 2, 3],
+ ):
+ self.parent = parent
+ self.vocab_size = 100
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.scope = scope
+ self.out_indices = out_indices
+ self.num_labels = num_labels
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ pixel_labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ pixel_labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels, pixel_labels
+
+ def get_config(self):
+ return Data2VecVisionConfig(
+ vocab_size=self.vocab_size,
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ out_indices=self.out_indices,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
+ model = TFData2VecVisionModel(config=config)
+ result = model(pixel_values, training=False)
+ # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
+ image_size = (
+ self.image_size
+ if isinstance(self.image_size, collections.abc.Iterable)
+ else (self.image_size, self.image_size)
+ )
+ patch_size = (
+ self.patch_size
+ if isinstance(self.image_size, collections.abc.Iterable)
+ else (self.patch_size, self.patch_size)
+ )
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
+ config.num_labels = self.type_sequence_label_size
+ model = TFData2VecVisionForImageClassification(config)
+
+ result = model(pixel_values, labels=labels, training=False)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels):
+ config.num_labels = self.num_labels
+ model = TFData2VecVisionForSemanticSegmentation(config)
+ result = model(pixel_values, training=False)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
+ )
+ result = model(pixel_values, labels=pixel_labels)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels, pixel_labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+ def prepare_config_and_inputs_for_keras_fit(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, _, _ = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values, "labels": tf.zeros((self.batch_size))}
+ return config, inputs_dict
+
+
+@require_tf
+class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as Data2VecVision does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (
+ (TFData2VecVisionModel, TFData2VecVisionForImageClassification, TFData2VecVisionForSemanticSegmentation)
+ if is_tf_available()
+ else ()
+ )
+
+ test_pruning = False
+ test_onnx = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = TFData2VecVisionModelTester(self)
+ self.config_tester = ConfigTester(
+ self, config_class=Data2VecVisionConfig, has_text_modality=False, hidden_size=37
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="Data2VecVision does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ # Data2VecVision does not use inputs_embeds
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.call)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_image_segmentation(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
+
+ @unittest.skip("Test was written for TF 1.x and isn't really relevant here")
+ def test_compile_tf_model(self):
+ pass
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ # in Data2VecVision, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ image_size = (
+ self.model_tester.image_size
+ if isinstance(self.model_tester.image_size, collections.abc.Iterable)
+ else (self.model_tester.image_size, self.model_tester.image_size)
+ )
+ patch_size = (
+ self.model_tester.patch_size
+ if isinstance(self.model_tester.patch_size, collections.abc.Iterable)
+ else (self.model_tester.patch_size, self.model_tester.patch_size)
+ )
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_len = num_patches + 1
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ chunk_length = getattr(self.model_tester, "chunk_length", None)
+ if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
+ encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
+
+ self.assertEqual(out_len + 1, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # Data2VecVision has a different seq_length
+ image_size = (
+ self.model_tester.image_size
+ if isinstance(self.model_tester.image_size, collections.abc.Iterable)
+ else (self.model_tester.image_size, self.model_tester.image_size)
+ )
+ patch_size = (
+ self.model_tester.patch_size
+ if isinstance(self.model_tester.patch_size, collections.abc.Iterable)
+ else (self.model_tester.patch_size, self.model_tester.patch_size)
+ )
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_length = num_patches + 1
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # Overriding this method since the base method won't be compatible with Data2VecVision.
+ def test_keras_fit(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ # Since `TFData2VecVisionModel` cannot operate with the default `fit()` method.
+ if model_class.__name__ != "TFData2VecVisionModel":
+ model = model_class(config)
+ if getattr(model, "hf_compute_loss", None):
+ # Test that model correctly compute the loss with kwargs
+ _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit()
+
+ label_names = {"labels"}
+ self.assertGreater(len(label_names), 0, msg="No matching label names found!")
+ labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
+ inputs_minus_labels = {
+ key: val for key, val in prepared_for_class.items() if key not in label_names
+ }
+ self.assertGreater(len(inputs_minus_labels), 0)
+ model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
+
+ # Make sure the model fits without crashing regardless of where we pass the labels
+ history1 = model.fit(
+ prepared_for_class,
+ validation_data=prepared_for_class,
+ steps_per_epoch=1,
+ validation_steps=1,
+ shuffle=False,
+ )
+ val_loss1 = history1.history["val_loss"][0]
+ history2 = model.fit(
+ inputs_minus_labels,
+ labels,
+ validation_data=(inputs_minus_labels, labels),
+ steps_per_epoch=1,
+ validation_steps=1,
+ shuffle=False,
+ )
+ val_loss2 = history2.history["val_loss"][0]
+ self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
+ # We override with a slightly higher tol value, as semseg models tend to diverge a bit more
+ super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
+
+ # Overriding this method since the base method won't be compatible with Data2VecVision.
+ def test_loss_computation(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ # Since `TFData2VecVisionModel` won't have labels against which we
+ # could compute loss.
+ if model_class.__name__ != "TFData2VecVisionModel":
+ model = model_class(config)
+ if getattr(model, "hf_compute_loss", None):
+ # The number of elements in the loss should be the same as the number of elements in the label
+ _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit()
+ added_label = prepared_for_class[
+ sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
+ ]
+ loss_size = tf.size(added_label)
+
+ # Test that model correctly compute the loss with kwargs
+ possible_input_names = {"input_ids", "pixel_values", "input_features"}
+ input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
+ model_input = prepared_for_class.pop(input_name)
+
+ loss = model(model_input, **prepared_for_class)[0]
+ self.assertEqual(loss.shape, [loss_size])
+
+ # Test that model correctly compute the loss with a dict
+ _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit()
+ loss = model(**prepared_for_class)[0]
+ self.assertEqual(loss.shape, [loss_size])
+
+ # Test that model correctly compute the loss with a tuple
+ label_keys = prepared_for_class.keys() - inputs_dict.keys()
+ signature = inspect.signature(model.call).parameters
+ signature_names = list(signature.keys())
+
+ # Create a dictionary holding the location of the tensors in the tuple
+ tuple_index_mapping = {0: input_name}
+ for label_key in label_keys:
+ label_key_index = signature_names.index(label_key)
+ tuple_index_mapping[label_key_index] = label_key
+ sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
+ # Initialize a list with their default values, update the values and convert to a tuple
+ list_input = []
+
+ for name in signature_names:
+ if name != "kwargs":
+ list_input.append(signature[name].default)
+
+ for index, value in sorted_tuple_index_mapping:
+ list_input[index] = prepared_for_class[value]
+
+ tuple_input = tuple(list_input)
+
+ # Send to model
+ loss = model(tuple_input[:-1])[0]
+
+ self.assertEqual(loss.shape, [loss_size])
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TFData2VecVisionModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_tf
+@require_vision
+class TFData2VecVisionModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return (
+ BeitFeatureExtractor.from_pretrained("facebook/data2vec-vision-base-ft1k")
+ if is_vision_available()
+ else None
+ )
+
+ @slow
+ def test_inference_image_classification_head_imagenet_1k(self):
+ model = TFData2VecVisionForImageClassification.from_pretrained("facebook/data2vec-vision-base-ft1k")
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="tf")
+
+ # forward pass
+ outputs = model(**inputs)
+ logits = outputs.logits
+
+ # verify the logits
+ expected_shape = tf.convert_to_tensor([1, 1000])
+ self.assertEqual(logits.shape, expected_shape)
+
+ expected_slice = tf.convert_to_tensor([0.3277, -0.1395, 0.0911])
+
+ tf.debugging.assert_near(logits[0, :3], expected_slice, atol=1e-4)
+
+ expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]]
+ self.assertEqual(tf.nn.top_k(outputs.logits[0], 2).indices.numpy().tolist(), expected_top2)
diff --git a/tests/detr/__init__.py b/tests/models/deberta/__init__.py
similarity index 100%
rename from tests/detr/__init__.py
rename to tests/models/deberta/__init__.py
diff --git a/tests/deberta/test_modeling_deberta.py b/tests/models/deberta/test_modeling_deberta.py
similarity index 97%
rename from tests/deberta/test_modeling_deberta.py
rename to tests/models/deberta/test_modeling_deberta.py
index 1902f9389d8dbb..940a82db4398a2 100644
--- a/tests/deberta/test_modeling_deberta.py
+++ b/tests/models/deberta/test_modeling_deberta.py
@@ -17,8 +17,8 @@
from transformers import DebertaConfig, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -130,6 +130,11 @@ def get_config(self):
pos_att_type=self.pos_att_type,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def check_loss_output(self, result):
self.parent.assertListEqual(list(result.loss.size()), [])
@@ -222,6 +227,7 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
+ fx_compatible = True
test_torchscript = False
test_pruning = False
test_head_masking = False
diff --git a/tests/deberta/test_modeling_tf_deberta.py b/tests/models/deberta/test_modeling_tf_deberta.py
similarity index 98%
rename from tests/deberta/test_modeling_tf_deberta.py
rename to tests/models/deberta/test_modeling_tf_deberta.py
index 7e2a3c3110eef9..c2584db30f1976 100644
--- a/tests/deberta/test_modeling_tf_deberta.py
+++ b/tests/models/deberta/test_modeling_tf_deberta.py
@@ -19,8 +19,8 @@
from transformers import DebertaConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/deberta/test_tokenization_deberta.py b/tests/models/deberta/test_tokenization_deberta.py
similarity index 90%
rename from tests/deberta/test_tokenization_deberta.py
rename to tests/models/deberta/test_tokenization_deberta.py
index 229ea22618139a..4aa53e13ff8d4e 100644
--- a/tests/deberta/test_tokenization_deberta.py
+++ b/tests/models/deberta/test_tokenization_deberta.py
@@ -22,7 +22,7 @@
from transformers.models.deberta.tokenization_deberta import VOCAB_FILES_NAMES
from transformers.testing_utils import slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -88,6 +88,12 @@ def test_full_tokenizer(self):
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+ def test_token_type_ids(self):
+ tokenizer = self.get_tokenizer()
+ tokd = tokenizer("Hello", "World")
+ expected_token_type_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
+ self.assertListEqual(tokd["token_type_ids"], expected_token_type_ids)
+
@slow
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("microsoft/deberta-base")
@@ -120,7 +126,9 @@ def test_tokenizer_integration(self):
sequences = [
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
"ALBERT incorporates two parameter reduction techniques",
- "The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
+ "The first one is a factorized embedding parameterization. By decomposing the large vocabulary"
+ " embedding matrix into two small matrices, we separate the size of the hidden layers from the size of"
+ " vocabulary embedding.",
]
encoding = tokenizer(sequences, padding=True)
@@ -149,7 +157,9 @@ def test_tokenizer_integration(self):
expected_decoded_sequence = [
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
"ALBERT incorporates two parameter reduction techniques",
- "The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
+ "The first one is a factorized embedding parameterization. By decomposing the large vocabulary"
+ " embedding matrix into two small matrices, we separate the size of the hidden layers from the size of"
+ " vocabulary embedding.",
]
self.assertDictEqual(encoding.data, expected_encoding)
diff --git a/tests/distilbert/__init__.py b/tests/models/deberta_v2/__init__.py
similarity index 100%
rename from tests/distilbert/__init__.py
rename to tests/models/deberta_v2/__init__.py
diff --git a/tests/deberta_v2/test_modeling_deberta_v2.py b/tests/models/deberta_v2/test_modeling_deberta_v2.py
similarity index 89%
rename from tests/deberta_v2/test_modeling_deberta_v2.py
rename to tests/models/deberta_v2/test_modeling_deberta_v2.py
index 48f3a673b69393..93436b901bb171 100644
--- a/tests/deberta_v2/test_modeling_deberta_v2.py
+++ b/tests/models/deberta_v2/test_modeling_deberta_v2.py
@@ -17,8 +17,8 @@
from transformers import DebertaV2Config, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -26,6 +26,7 @@
from transformers import (
DebertaV2ForMaskedLM,
+ DebertaV2ForMultipleChoice,
DebertaV2ForQuestionAnswering,
DebertaV2ForSequenceClassification,
DebertaV2ForTokenClassification,
@@ -192,6 +193,23 @@ def create_and_check_deberta_for_question_answering(
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+ def create_and_check_deberta_for_multiple_choice(
+ self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ model = DebertaV2ForMultipleChoice(config=config)
+ model.to(torch_device)
+ model.eval()
+ multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+ result = model(
+ multiple_choice_inputs_ids,
+ attention_mask=multiple_choice_input_mask,
+ token_type_ids=multiple_choice_token_type_ids,
+ labels=choice_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
+
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -217,11 +235,13 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
DebertaV2ForSequenceClassification,
DebertaV2ForTokenClassification,
DebertaV2ForQuestionAnswering,
+ DebertaV2ForMultipleChoice,
)
if is_torch_available()
else ()
)
+ fx_compatible = True
test_torchscript = False
test_pruning = False
test_head_masking = False
@@ -254,6 +274,10 @@ def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deberta_for_token_classification(*config_and_inputs)
+ def test_for_multiple_choice(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_deberta_for_multiple_choice(*config_and_inputs)
+
@slow
def test_model_from_pretrained(self):
for model_name in DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
diff --git a/tests/deberta_v2/test_modeling_tf_deberta_v2.py b/tests/models/deberta_v2/test_modeling_tf_deberta_v2.py
similarity index 98%
rename from tests/deberta_v2/test_modeling_tf_deberta_v2.py
rename to tests/models/deberta_v2/test_modeling_tf_deberta_v2.py
index 4fd967c2fa6e76..b2cc8896e46ee5 100644
--- a/tests/deberta_v2/test_modeling_tf_deberta_v2.py
+++ b/tests/models/deberta_v2/test_modeling_tf_deberta_v2.py
@@ -19,8 +19,8 @@
from transformers import DebertaV2Config, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/deberta_v2/test_tokenization_deberta_v2.py b/tests/models/deberta_v2/test_tokenization_deberta_v2.py
similarity index 97%
rename from tests/deberta_v2/test_tokenization_deberta_v2.py
rename to tests/models/deberta_v2/test_tokenization_deberta_v2.py
index ee52c8706af905..c84034c7f0bc7d 100644
--- a/tests/deberta_v2/test_tokenization_deberta_v2.py
+++ b/tests/models/deberta_v2/test_tokenization_deberta_v2.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from transformers import DebertaV2Tokenizer, DebertaV2TokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/spiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
@require_sentencepiece
diff --git a/tests/dit/__init__.py b/tests/models/decision_transformer/__init__.py
similarity index 100%
rename from tests/dit/__init__.py
rename to tests/models/decision_transformer/__init__.py
diff --git a/tests/decision_transformer/test_modeling_decision_transformer.py b/tests/models/decision_transformer/test_modeling_decision_transformer.py
similarity index 97%
rename from tests/decision_transformer/test_modeling_decision_transformer.py
rename to tests/models/decision_transformer/test_modeling_decision_transformer.py
index 0843ce630ee30f..9124c64fa1d45a 100644
--- a/tests/decision_transformer/test_modeling_decision_transformer.py
+++ b/tests/models/decision_transformer/test_modeling_decision_transformer.py
@@ -21,9 +21,9 @@
from transformers import DecisionTransformerConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/dpr/__init__.py b/tests/models/deit/__init__.py
similarity index 100%
rename from tests/dpr/__init__.py
rename to tests/models/deit/__init__.py
diff --git a/tests/deit/test_feature_extraction_deit.py b/tests/models/deit/test_feature_extraction_deit.py
similarity index 98%
rename from tests/deit/test_feature_extraction_deit.py
rename to tests/models/deit/test_feature_extraction_deit.py
index 8eaa8313499c42..92a477f182fc3a 100644
--- a/tests/deit/test_feature_extraction_deit.py
+++ b/tests/models/deit/test_feature_extraction_deit.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py
similarity index 76%
rename from tests/deit/test_modeling_deit.py
rename to tests/models/deit/test_modeling_deit.py
index f8723c18756577..4559fa0c7127bf 100644
--- a/tests/deit/test_modeling_deit.py
+++ b/tests/models/deit/test_modeling_deit.py
@@ -24,8 +24,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -92,9 +92,9 @@ def __init__(
self.scope = scope
self.encoder_stride = encoder_stride
- # in DeiT, the expected seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
+ # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
num_patches = (image_size // patch_size) ** 2
- self.expected_seq_length = num_patches + 2
+ self.seq_length = num_patches + 2
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -129,9 +129,7 @@ def create_and_check_model(self, config, pixel_values, labels):
model.to(torch_device)
model.eval()
result = model(pixel_values)
- self.parent.assertEqual(
- result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
- )
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
@@ -181,8 +179,8 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()
+ @unittest.skip(reason="DeiT does not use inputs_embeds")
def test_inputs_embeds(self):
- # DeiT does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -210,94 +208,9 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- seq_len = self.model_tester.expected_seq_length
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- self.assertEqual(out_len + 1, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- seq_length = self.model_tester.expected_seq_length
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
# special case for DeiTForImageClassificationWithTeacher model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -403,10 +316,6 @@ def test_problem_types(self):
loss.backward()
- def test_for_image_classification(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
-
@slow
def test_model_from_pretrained(self):
for model_name in DEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
diff --git a/tests/dpt/__init__.py b/tests/models/detr/__init__.py
similarity index 100%
rename from tests/dpt/__init__.py
rename to tests/models/detr/__init__.py
diff --git a/tests/detr/test_feature_extraction_detr.py b/tests/models/detr/test_feature_extraction_detr.py
similarity index 99%
rename from tests/detr/test_feature_extraction_detr.py
rename to tests/models/detr/test_feature_extraction_detr.py
index 0f88d57c73472e..58bde80fbbb11e 100644
--- a/tests/detr/test_feature_extraction_detr.py
+++ b/tests/models/detr/test_feature_extraction_detr.py
@@ -23,7 +23,7 @@
from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py
similarity index 99%
rename from tests/detr/test_modeling_detr.py
rename to tests/models/detr/test_modeling_detr.py
index 45fb370199395c..7b0b7eeb75457c 100644
--- a/tests/detr/test_modeling_detr.py
+++ b/tests/models/detr/test_modeling_detr.py
@@ -23,9 +23,9 @@
from transformers.testing_utils import require_timm, require_vision, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
if is_timm_available():
diff --git a/tests/electra/__init__.py b/tests/models/distilbert/__init__.py
similarity index 100%
rename from tests/electra/__init__.py
rename to tests/models/distilbert/__init__.py
diff --git a/tests/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py
similarity index 98%
rename from tests/distilbert/test_modeling_distilbert.py
rename to tests/models/distilbert/test_modeling_distilbert.py
index 2dfd31ac064b9c..9b4606b484ee55 100644
--- a/tests/distilbert/test_modeling_distilbert.py
+++ b/tests/models/distilbert/test_modeling_distilbert.py
@@ -19,8 +19,8 @@
from transformers import DistilBertConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/distilbert/test_modeling_flax_distilbert.py b/tests/models/distilbert/test_modeling_flax_distilbert.py
similarity index 98%
rename from tests/distilbert/test_modeling_flax_distilbert.py
rename to tests/models/distilbert/test_modeling_flax_distilbert.py
index 2ad10c07859ee8..e0f609b4ddf309 100644
--- a/tests/distilbert/test_modeling_flax_distilbert.py
+++ b/tests/models/distilbert/test_modeling_flax_distilbert.py
@@ -19,7 +19,7 @@
from transformers import DistilBertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
diff --git a/tests/distilbert/test_modeling_tf_distilbert.py b/tests/models/distilbert/test_modeling_tf_distilbert.py
similarity index 98%
rename from tests/distilbert/test_modeling_tf_distilbert.py
rename to tests/models/distilbert/test_modeling_tf_distilbert.py
index 5266723f1f86a4..e52532d5618aae 100644
--- a/tests/distilbert/test_modeling_tf_distilbert.py
+++ b/tests/models/distilbert/test_modeling_tf_distilbert.py
@@ -19,8 +19,8 @@
from transformers import DistilBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/distilbert/test_tokenization_distilbert.py b/tests/models/distilbert/test_tokenization_distilbert.py
similarity index 100%
rename from tests/distilbert/test_tokenization_distilbert.py
rename to tests/models/distilbert/test_tokenization_distilbert.py
diff --git a/tests/encoder_decoder/__init__.py b/tests/models/dit/__init__.py
similarity index 100%
rename from tests/encoder_decoder/__init__.py
rename to tests/models/dit/__init__.py
diff --git a/tests/dit/test_modeling_dit.py b/tests/models/dit/test_modeling_dit.py
similarity index 100%
rename from tests/dit/test_modeling_dit.py
rename to tests/models/dit/test_modeling_dit.py
diff --git a/tests/flaubert/__init__.py b/tests/models/dpr/__init__.py
similarity index 100%
rename from tests/flaubert/__init__.py
rename to tests/models/dpr/__init__.py
diff --git a/tests/dpr/test_modeling_dpr.py b/tests/models/dpr/test_modeling_dpr.py
similarity index 98%
rename from tests/dpr/test_modeling_dpr.py
rename to tests/models/dpr/test_modeling_dpr.py
index 7aef57f753aad0..708f1d53c3a46e 100644
--- a/tests/dpr/test_modeling_dpr.py
+++ b/tests/models/dpr/test_modeling_dpr.py
@@ -20,8 +20,8 @@
from transformers import DPRConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/dpr/test_modeling_tf_dpr.py b/tests/models/dpr/test_modeling_tf_dpr.py
similarity index 98%
rename from tests/dpr/test_modeling_tf_dpr.py
rename to tests/models/dpr/test_modeling_tf_dpr.py
index ffce36efc3a662..86ef3837f1fa01 100644
--- a/tests/dpr/test_modeling_tf_dpr.py
+++ b/tests/models/dpr/test_modeling_tf_dpr.py
@@ -18,8 +18,8 @@
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/dpr/test_tokenization_dpr.py b/tests/models/dpr/test_tokenization_dpr.py
similarity index 100%
rename from tests/dpr/test_tokenization_dpr.py
rename to tests/models/dpr/test_tokenization_dpr.py
diff --git a/tests/fnet/__init__.py b/tests/models/dpt/__init__.py
similarity index 100%
rename from tests/fnet/__init__.py
rename to tests/models/dpt/__init__.py
diff --git a/tests/dpt/test_feature_extraction_dpt.py b/tests/models/dpt/test_feature_extraction_dpt.py
similarity index 98%
rename from tests/dpt/test_feature_extraction_dpt.py
rename to tests/models/dpt/test_feature_extraction_dpt.py
index 83abb59eec4f19..a0cf1cba23af5b 100644
--- a/tests/dpt/test_feature_extraction_dpt.py
+++ b/tests/models/dpt/test_feature_extraction_dpt.py
@@ -21,7 +21,7 @@
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py
similarity index 72%
rename from tests/dpt/test_modeling_dpt.py
rename to tests/models/dpt/test_modeling_dpt.py
index 08bb550e0e56a2..3266ea78a71aaa 100644
--- a/tests/dpt/test_modeling_dpt.py
+++ b/tests/models/dpt/test_modeling_dpt.py
@@ -23,8 +23,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -81,9 +81,9 @@ def __init__(
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
- # expected sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
+ # sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
- self.expected_seq_length = num_patches + 1
+ self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -118,9 +118,7 @@ def create_and_check_model(self, config, pixel_values, labels):
model.to(torch_device)
model.eval()
result = model(pixel_values)
- self.parent.assertEqual(
- result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
- )
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
config.num_labels = self.num_labels
@@ -167,8 +165,8 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()
+ @unittest.skip(reason="DPT does not use inputs_embeds")
def test_inputs_embeds(self):
- # DPT does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -204,97 +202,6 @@ def test_for_semantic_segmentation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- # in DPT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
- seq_len = self.model_tester.expected_seq_length
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- self.assertEqual(len(outputs.attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- self.assertEqual(out_len + 1, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- # DPT has a different seq_length
- seq_len = self.model_tester.expected_seq_length
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_len, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
def test_training(self):
for model_class in self.all_model_classes:
if model_class.__name__ == "DPTForDepthEstimation":
diff --git a/tests/fsmt/__init__.py b/tests/models/electra/__init__.py
similarity index 100%
rename from tests/fsmt/__init__.py
rename to tests/models/electra/__init__.py
diff --git a/tests/electra/test_modeling_electra.py b/tests/models/electra/test_modeling_electra.py
similarity index 99%
rename from tests/electra/test_modeling_electra.py
rename to tests/models/electra/test_modeling_electra.py
index 4a6a1b1357e488..9a6ba063ea3d4e 100644
--- a/tests/electra/test_modeling_electra.py
+++ b/tests/models/electra/test_modeling_electra.py
@@ -20,8 +20,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/electra/test_modeling_flax_electra.py b/tests/models/electra/test_modeling_flax_electra.py
similarity index 96%
rename from tests/electra/test_modeling_flax_electra.py
rename to tests/models/electra/test_modeling_flax_electra.py
index 390c8be39e219f..cd1a795a19efec 100644
--- a/tests/electra/test_modeling_flax_electra.py
+++ b/tests/models/electra/test_modeling_flax_electra.py
@@ -5,11 +5,12 @@
from transformers import ElectraConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
from transformers.models.electra.modeling_flax_electra import (
+ FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
@@ -110,6 +111,7 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxElectraModel,
+ FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForPreTraining,
FlaxElectraForTokenClassification,
diff --git a/tests/electra/test_modeling_tf_electra.py b/tests/models/electra/test_modeling_tf_electra.py
similarity index 99%
rename from tests/electra/test_modeling_tf_electra.py
rename to tests/models/electra/test_modeling_tf_electra.py
index ff2acd37e69f17..0c0c4f77ab3245 100644
--- a/tests/electra/test_modeling_tf_electra.py
+++ b/tests/models/electra/test_modeling_tf_electra.py
@@ -19,8 +19,8 @@
from transformers import ElectraConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/funnel/__init__.py b/tests/models/encoder_decoder/__init__.py
similarity index 100%
rename from tests/funnel/__init__.py
rename to tests/models/encoder_decoder/__init__.py
diff --git a/tests/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
similarity index 99%
rename from tests/encoder_decoder/test_modeling_encoder_decoder.py
rename to tests/models/encoder_decoder/test_modeling_encoder_decoder.py
index 8412ccb3894a4d..b356b3ee0ba112 100644
--- a/tests/encoder_decoder/test_modeling_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
@@ -20,13 +20,13 @@
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
+from ...test_modeling_common import ids_tensor
from ..bart.test_modeling_bart import BartStandaloneDecoderModelTester
from ..bert.test_modeling_bert import BertModelTester
from ..bert_generation.test_modeling_bert_generation import BertGenerationEncoderTester
from ..gpt2.test_modeling_gpt2 import GPT2ModelTester
from ..prophetnet.test_modeling_prophetnet import ProphetNetStandaloneDecoderModelTester
from ..roberta.test_modeling_roberta import RobertaModelTester
-from ..test_modeling_common import ids_tensor
if is_torch_available():
diff --git a/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py
similarity index 93%
rename from tests/encoder_decoder/test_modeling_flax_encoder_decoder.py
rename to tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py
index d0ab1a25d1d8da..ce7a79ead2fe63 100644
--- a/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py
@@ -22,10 +22,10 @@
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
+from ...test_modeling_flax_common import ids_tensor
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
-from ..test_modeling_flax_common import ids_tensor
if is_flax_available():
@@ -33,6 +33,7 @@
AutoTokenizer,
EncoderDecoderConfig,
FlaxBartForCausalLM,
+ FlaxBertForCausalLM,
FlaxBertModel,
FlaxEncoderDecoderModel,
FlaxGPT2LMHeadModel,
@@ -545,6 +546,43 @@ def get_pretrained_model(self):
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base")
+@require_flax
+class FlaxBertEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
+ def get_encoder_decoder_model(self, config, decoder_config):
+ encoder_model = FlaxBertModel(config)
+ decoder_model = FlaxBertForCausalLM(decoder_config)
+ return encoder_model, decoder_model
+
+ def prepare_config_and_inputs(self):
+ model_tester_encoder = FlaxBertModelTester(self, batch_size=13)
+ model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
+ encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
+ decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
+ (config, input_ids, token_type_ids, attention_mask) = encoder_config_and_inputs
+ (
+ decoder_config,
+ decoder_input_ids,
+ decoder_attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ) = decoder_config_and_inputs
+
+ # make sure that cross attention layers are added
+ decoder_config.add_cross_attention = True
+ return {
+ "config": config,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "decoder_config": decoder_config,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ "encoder_hidden_states": encoder_hidden_states,
+ }
+
+ def get_pretrained_model(self):
+ return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
+
+
@require_flax
class FlaxEncoderDecoderModelTest(unittest.TestCase):
def get_from_encoderdecoder_pretrained_model(self):
diff --git a/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
similarity index 99%
rename from tests/encoder_decoder/test_modeling_tf_encoder_decoder.py
rename to tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
index bedd72fe247b35..74eb59b4e016e1 100644
--- a/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
@@ -24,11 +24,11 @@
from transformers import is_tf_available, is_torch_available
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_torch, slow, torch_device
+from ...test_modeling_tf_common import ids_tensor
from ..bert.test_modeling_tf_bert import TFBertModelTester
from ..gpt2.test_modeling_tf_gpt2 import TFGPT2ModelTester
from ..rembert.test_modeling_tf_rembert import TFRemBertModelTester
from ..roberta.test_modeling_tf_roberta import TFRobertaModelTester
-from ..test_modeling_tf_common import ids_tensor
if is_tf_available():
diff --git a/tests/glpn/__init__.py b/tests/models/flaubert/__init__.py
similarity index 100%
rename from tests/glpn/__init__.py
rename to tests/models/flaubert/__init__.py
diff --git a/tests/flaubert/test_modeling_flaubert.py b/tests/models/flaubert/test_modeling_flaubert.py
similarity index 99%
rename from tests/flaubert/test_modeling_flaubert.py
rename to tests/models/flaubert/test_modeling_flaubert.py
index 4c01abd459e835..da29cac6dd588c 100644
--- a/tests/flaubert/test_modeling_flaubert.py
+++ b/tests/models/flaubert/test_modeling_flaubert.py
@@ -19,8 +19,8 @@
from transformers import FlaubertConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/flaubert/test_modeling_tf_flaubert.py b/tests/models/flaubert/test_modeling_tf_flaubert.py
similarity index 98%
rename from tests/flaubert/test_modeling_tf_flaubert.py
rename to tests/models/flaubert/test_modeling_tf_flaubert.py
index 86bcd6ea64848a..09ba6f45d8d0b9 100644
--- a/tests/flaubert/test_modeling_tf_flaubert.py
+++ b/tests/models/flaubert/test_modeling_tf_flaubert.py
@@ -18,8 +18,8 @@
from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/gpt2/__init__.py b/tests/models/flava/__init__.py
similarity index 100%
rename from tests/gpt2/__init__.py
rename to tests/models/flava/__init__.py
diff --git a/tests/models/flava/test_feature_extraction_flava.py b/tests/models/flava/test_feature_extraction_flava.py
new file mode 100644
index 00000000000000..793aa913aeb04b
--- /dev/null
+++ b/tests/models/flava/test_feature_extraction_flava.py
@@ -0,0 +1,347 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FlavaFeatureExtractor
+ from transformers.models.flava.feature_extraction_flava import (
+ FLAVA_CODEBOOK_MEAN,
+ FLAVA_CODEBOOK_STD,
+ FLAVA_IMAGE_MEAN,
+ FLAVA_IMAGE_STD,
+ )
+else:
+ FLAVA_IMAGE_MEAN = FLAVA_IMAGE_STD = FLAVA_CODEBOOK_MEAN = FLAVA_CODEBOOK_STD = None
+
+
+class FlavaFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=224,
+ do_center_crop=True,
+ crop_size=224,
+ resample=None,
+ do_normalize=True,
+ image_mean=FLAVA_IMAGE_MEAN,
+ image_std=FLAVA_IMAGE_STD,
+ input_size_patches=14,
+ total_mask_patches=75,
+ mask_group_max_patches=None,
+ mask_group_min_patches=16,
+ mask_group_min_aspect_ratio=0.3,
+ mask_group_max_aspect_ratio=None,
+ codebook_do_resize=True,
+ codebook_size=112,
+ codebook_resample=None,
+ codebook_do_center_crop=True,
+ codebook_crop_size=112,
+ codebook_do_map_pixels=True,
+ codebook_do_normalize=True,
+ codebook_image_mean=FLAVA_CODEBOOK_MEAN,
+ codebook_image_std=FLAVA_CODEBOOK_STD,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.do_resize = do_resize
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.size = size
+ self.resample = resample if resample is not None else Image.BICUBIC
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+
+ self.input_size_patches = input_size_patches
+ self.total_mask_patches = total_mask_patches
+ self.mask_group_max_patches = mask_group_max_patches
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
+ self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
+
+ self.codebook_do_resize = codebook_do_resize
+ self.codebook_size = codebook_size
+ self.codebook_resample = codebook_resample if codebook_resample is not None else Image.LANCZOS
+ self.codebook_do_center_crop = codebook_do_center_crop
+ self.codebook_crop_size = codebook_crop_size
+ self.codebook_do_map_pixels = codebook_do_map_pixels
+ self.codebook_do_normalize = codebook_do_normalize
+ self.codebook_image_mean = codebook_image_mean
+ self.codebook_image_std = codebook_image_std
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "resample": self.resample,
+ "do_center_crop": self.do_center_crop,
+ "crop_size": self.crop_size,
+ "input_size_patches": self.input_size_patches,
+ "total_mask_patches": self.total_mask_patches,
+ "mask_group_max_patches": self.mask_group_max_patches,
+ "mask_group_min_patches": self.mask_group_min_patches,
+ "mask_group_min_aspect_ratio": self.mask_group_min_aspect_ratio,
+ "mask_group_max_aspect_ratio": self.mask_group_min_aspect_ratio,
+ "codebook_do_resize": self.codebook_do_resize,
+ "codebook_size": self.codebook_size,
+ "codebook_resample": self.codebook_resample,
+ "codebook_do_center_crop": self.codebook_do_center_crop,
+ "codebook_crop_size": self.codebook_crop_size,
+ "codebook_do_map_pixels": self.codebook_do_map_pixels,
+ "codebook_do_normalize": self.codebook_do_normalize,
+ "codebook_image_mean": self.codebook_image_mean,
+ "codebook_image_std": self.codebook_image_std,
+ }
+
+ def get_expected_image_size(self):
+ return (self.size, self.size) if not isinstance(self.size, tuple) else self.size
+
+ def get_expected_mask_size(self):
+ return (
+ (self.input_size_patches, self.input_size_patches)
+ if not isinstance(self.input_size_patches, tuple)
+ else self.input_size_patches
+ )
+
+ def get_expected_codebook_image_size(self):
+ if not isinstance(self.codebook_size, tuple):
+ return (self.codebook_size, self.codebook_size)
+ else:
+ return self.codebook_size
+
+
+@require_torch
+@require_vision
+class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = FlavaFeatureExtractor if is_vision_available() else None
+ maxDiff = None
+
+ def setUp(self):
+ self.feature_extract_tester = FlavaFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "resample"))
+ self.assertTrue(hasattr(feature_extractor, "crop_size"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "masking_generator"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_size"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_resample"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_crop_size"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_do_map_pixels"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "codebook_image_std"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
+
+ # Test no bool masked pos
+ self.assertFalse("bool_masked_pos" in encoded_images)
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+
+ # Test no bool masked pos
+ self.assertFalse("bool_masked_pos" in encoded_images)
+
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def _test_call_framework(self, instance_class, prepare_kwargs):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, **prepare_kwargs)
+ for image in image_inputs:
+ self.assertIsInstance(image, instance_class)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ encoded_images = feature_extractor(image_inputs, return_image_mask=True, return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_mask_size()
+ self.assertEqual(
+ encoded_images.bool_masked_pos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ # Test masking
+ encoded_images = feature_extractor(image_inputs, return_image_mask=True, return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_mask_size()
+ self.assertEqual(
+ encoded_images.bool_masked_pos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_call_numpy(self):
+ self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
+
+ def test_call_pytorch(self):
+ self._test_call_framework(torch.Tensor, prepare_kwargs={"torchify": True})
+
+ def test_masking(self):
+ # Initialize feature_extractor
+ random.seed(1234)
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_image_mask=True, return_tensors="pt")
+ self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75)
+
+ def test_codebook_pixels(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_codebook_image_size()
+ self.assertEqual(
+ encoded_images.codebook_pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_codebook_pixels=True, return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_codebook_image_size()
+ self.assertEqual(
+ encoded_images.codebook_pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py
new file mode 100644
index 00000000000000..62b89e3977c35a
--- /dev/null
+++ b/tests/models/flava/test_modeling_flava.py
@@ -0,0 +1,1228 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch FLAVA model. """
+
+
+import inspect
+import os
+import random
+import tempfile
+import unittest
+
+import numpy as np
+
+import requests
+from transformers import (
+ FlavaConfig,
+ FlavaImageCodebookConfig,
+ FlavaImageConfig,
+ FlavaMultimodalConfig,
+ FlavaTextConfig,
+)
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import (
+ FlavaForPreTraining,
+ FlavaImageCodebook,
+ FlavaImageModel,
+ FlavaModel,
+ FlavaMultimodalModel,
+ FlavaTextModel,
+ )
+ from transformers.models.flava.modeling_flava import (
+ FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST,
+ FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
+ )
+else:
+ FlavaModel = None
+ FlavaForPreTraining = None
+ torch = {}
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FlavaProcessor
+
+
+class FlavaImageModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ qkv_bias=True,
+ mask_token=True,
+ vocab_size=8192,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.mask_token = mask_token
+ self.vocab_size = vocab_size
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ num_patches = self.image_size // self.patch_size
+ bool_masked_pos = (
+ torch.rand((self.batch_size, num_patches, num_patches), device=pixel_values.device) < 0.9
+ ).long()
+ config = self.get_config()
+ return config, pixel_values, bool_masked_pos
+
+ def get_config(self):
+ return FlavaImageConfig(
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ qkv_bias=self.qkv_bias,
+ mask_token=self.mask_token,
+ vocab_size=self.vocab_size,
+ )
+
+ def create_and_check_model(self, config, pixel_values, bool_masked_pos):
+ model = FlavaImageModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(pixel_values, bool_masked_pos)
+ # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
+ image_size = (self.image_size, self.image_size)
+ patch_size = (self.patch_size, self.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, bool_masked_pos = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values, "bool_masked_pos": bool_masked_pos}
+ return config, inputs_dict
+
+
+@require_torch
+class FlavaImageModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as FLAVA does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (FlavaImageModel,) if is_torch_available() else ()
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = FlavaImageModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=FlavaImageConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ # in FLAVA, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_len = num_patches + 1
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # FLAVA has a different seq_length
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_length = num_patches + 1
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ # skip this test as FlavaImageModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FlavaImageModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaImageModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaTextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ vocab_size=30522,
+ type_vocab_size=2,
+ max_position_embeddings=512,
+ position_embedding_type="absolute",
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ qkv_bias=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.seq_length = seq_length
+ self.vocab_size = vocab_size
+ self.type_vocab_size = type_vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.position_embedding_type = position_embedding_type
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.pad_token_id = pad_token_id
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ if input_mask is not None:
+ batch_size, seq_length = input_mask.shape
+ rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
+ for batch_idx, start_index in enumerate(rnd_start_indices):
+ input_mask[batch_idx, :start_index] = 1
+ input_mask[batch_idx, start_index:] = 0
+
+ token_type_ids = None
+
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ config = self.get_config()
+
+ return config, input_ids, token_type_ids, input_mask
+
+ def get_config(self):
+ return FlavaTextConfig(
+ vocab_size=self.vocab_size,
+ type_vocab_size=self.type_vocab_size,
+ max_position_embeddings=self.max_position_embeddings,
+ position_embedding_type=self.position_embedding_type,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ pad_token_id=self.pad_token_id,
+ qkv_bias=self.qkv_bias,
+ )
+
+ def create_and_check_model(self, config, input_ids, token_type_ids, input_mask):
+ model = FlavaTextModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, token_type_ids, input_mask = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class FlavaTextModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlavaTextModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = FlavaTextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=FlavaTextConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ # skip this test as FlavaTextModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FlavaTextModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaTextModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaMultimodalModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ seq_length=44,
+ use_input_mask=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ qkv_bias=True,
+ ce_ignore_index=-100,
+ use_cls_token=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.use_input_mask = use_input_mask
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.ce_ignore_index = ce_ignore_index
+ self.use_cls_token = use_cls_token
+
+ def prepare_config_and_inputs(self):
+ hidden_states = floats_tensor([self.batch_size, self.seq_length - 1, self.hidden_size])
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ if input_mask is not None:
+ batch_size, seq_length = input_mask.shape
+ rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
+ for batch_idx, start_index in enumerate(rnd_start_indices):
+ input_mask[batch_idx, :start_index] = 1
+ input_mask[batch_idx, start_index:] = 0
+
+ config = self.get_config()
+
+ return config, hidden_states, input_mask
+
+ def get_config(self):
+ return FlavaMultimodalConfig(
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ qkv_bias=self.qkv_bias,
+ use_cls_token=self.use_cls_token,
+ ce_ignore_index=self.ce_ignore_index,
+ )
+
+ def create_and_check_model(self, config, hidden_states, input_mask):
+ model = FlavaMultimodalModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(hidden_states, attention_mask=input_mask)
+ result = model(hidden_states)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, hidden_states, input_mask = config_and_inputs
+ inputs_dict = {"hidden_states": hidden_states, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class FlavaMultimodalModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlavaMultimodalModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_resize_embeddings = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = FlavaMultimodalModelTester(self)
+ self.config_tester = ConfigTester(
+ self, config_class=FlavaMultimodalConfig, has_text_modality=False, hidden_size=37
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["hidden_states"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model_common_attributes(self):
+ # No embedding in multimodal model
+ pass
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ # skip this test as FlavaMultimodalModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FlavaMultimodalModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaMultimodalModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaImageCodebookTester:
+ def __init__(self, parent, batch_size=12, image_size=112, num_channels=3):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ config = self.get_config()
+
+ return config, pixel_values
+
+ def get_config(self):
+ return FlavaImageCodebookConfig()
+
+ def create_and_check_model(self, config, pixel_values):
+ model = FlavaImageCodebook(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.shape, (self.batch_size, config.vocab_size, self.image_size // 8, self.image_size // 8)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlavaImageCodebook,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_resize_embeddings = False
+ test_torchscript = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = FlavaImageCodebookTester(self)
+ self.config_tester = ConfigTester(self, config_class=FlavaImageCodebookConfig, has_text_modality=False)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ @unittest.skip(reason="Flava does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ def test_model_common_attributes(self):
+ # No embedding in multimodal model
+ pass
+
+ def test_training(self):
+ pass
+
+ def test_hidden_states_output(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ # no attentions
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ def test_model_outputs_equivalence(self):
+ pass
+
+ # skip this test as FlavaImageCodebook has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FlavaImageCodebook has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaImageCodebook.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaModelTester:
+ model_class = FlavaModel
+
+ def __init__(
+ self,
+ parent,
+ is_training=True,
+ hidden_size=32,
+ projection_dim=32,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ ):
+ self.parent = parent
+ self.image_model_tester = FlavaImageModelTester(parent)
+ self.text_model_tester = FlavaTextModelTester(parent)
+ self.multimodal_model_tester = FlavaMultimodalModelTester(parent)
+ self.image_codebook_tester = FlavaImageCodebookTester(parent)
+ self.is_training = is_training
+ self.config_tester = ConfigTester(self, config_class=FlavaConfig, hidden_size=37)
+ self.hidden_size = hidden_size
+ self.projection_dim = projection_dim
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def prepare_config_and_inputs_for_common(self):
+ _, pixel_values, bool_masked_pos = self.image_model_tester.prepare_config_and_inputs()
+ _, input_ids, token_type_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+
+ config = self.get_config()
+
+ return config, {
+ "input_ids": input_ids,
+ "token_type_ids": token_type_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "bool_masked_pos": bool_masked_pos,
+ }
+
+ def get_config(self):
+ return FlavaConfig.from_configs(
+ self.image_model_tester.get_config(),
+ self.text_model_tester.get_config(),
+ self.multimodal_model_tester.get_config(),
+ self.image_codebook_tester.get_config(),
+ hidden_size=self.hidden_size,
+ projection_dim=self.projection_dim,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ )
+
+ def create_and_check_model(self, config, inputs):
+ self._test_model(config, inputs, test_image=True)
+ self._test_model(config, inputs, test_text=True)
+ self._test_model(config, inputs, test_image=True, test_text=True)
+
+ def _test_model(self, config, inputs, test_image=False, test_text=False):
+ model = self.model_class(config).to(torch_device).eval()
+ with torch.no_grad():
+ result = model(
+ input_ids=inputs["input_ids"] if test_text else None,
+ attention_mask=inputs["attention_mask"] if test_text else None,
+ token_type_ids=inputs["token_type_ids"] if test_text else None,
+ pixel_values=inputs["pixel_values"] if test_image else None,
+ bool_masked_pos=inputs["bool_masked_pos"] if test_image else None,
+ )
+ image_size = (self.image_model_tester.image_size, self.image_model_tester.image_size)
+ patch_size = (self.image_model_tester.patch_size, self.image_model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ if test_image:
+ self.parent.assertEqual(
+ result.image_embeddings.shape,
+ (self.image_model_tester.batch_size, num_patches + 1, self.image_model_tester.hidden_size),
+ )
+ else:
+ self.parent.assertIsNone(result.image_embeddings)
+
+ if test_text:
+ self.parent.assertEqual(
+ result.text_embeddings.shape,
+ (
+ self.text_model_tester.batch_size,
+ self.text_model_tester.seq_length,
+ self.text_model_tester.hidden_size,
+ ),
+ )
+ else:
+ self.parent.assertIsNone(result.text_embeddings)
+
+ if test_image and test_text:
+ self.parent.assertEqual(
+ result.multimodal_embeddings.shape,
+ (
+ self.multimodal_model_tester.batch_size,
+ self.text_model_tester.seq_length + num_patches + 2,
+ self.multimodal_model_tester.hidden_size,
+ ),
+ )
+ else:
+ self.parent.assertIsNone(result.multimodal_embeddings)
+
+
+@require_torch
+class FlavaModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (FlavaModel,) if is_torch_available() else ()
+ class_for_tester = FlavaModelTester
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+
+ def setUp(self):
+ self.model_tester = self.class_for_tester(self)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ # hidden_states are tested in individual model tests
+ def test_hidden_states_output(self):
+ pass
+
+ # input_embeds are tested in individual model tests
+ def test_inputs_embeds(self):
+ pass
+
+ # tested in individual model tests
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ # FlavaModel does not have input/output embeddings
+ def test_model_common_attributes(self):
+ pass
+
+ # override as the `logit_scale` parameter initilization is different for FLAVA
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ # check if `logit_scale` is initilized as per the original implementation
+ if name == "logit_scale" or name == "flava.logit_scale":
+ self.assertAlmostEqual(
+ param.data.item(),
+ np.log(1 / 0.07),
+ delta=1e-3,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def _create_and_check_torchscript(self, config, inputs_dict):
+ if not self.test_torchscript:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.torchscript = True
+ configs_no_init.return_dict = False
+ configs_no_init.return_loss = False
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+
+ try:
+ input_ids = inputs_dict["input_ids"]
+ pixel_values = inputs_dict["pixel_values"] # FLAVA needs pixel_values
+
+ if "input_ids_masked" in inputs_dict:
+ # For pretraining
+ inputs = (input_ids, inputs_dict["input_ids_masked"], pixel_values)
+ else:
+ inputs = (input_ids, pixel_values)
+
+ traced_model = torch.jit.trace(model, inputs)
+ except RuntimeError:
+ self.fail("Couldn't trace module.")
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+ try:
+ torch.jit.save(traced_model, pt_file_name)
+ except Exception:
+ self.fail("Couldn't save module.")
+
+ try:
+ loaded_model = torch.jit.load(pt_file_name)
+ except Exception:
+ self.fail("Couldn't load module.")
+
+ model.to(torch_device)
+ model.eval()
+
+ loaded_model.to(torch_device)
+ loaded_model.eval()
+
+ model_state_dict = model.state_dict()
+ loaded_model_state_dict = loaded_model.state_dict()
+ # Non persistent buffers won't be in original state dict
+ loaded_model_state_dict.pop("text_model.embeddings.token_type_ids", None)
+
+ self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+
+ models_equal = True
+ for layer_name, p1 in model_state_dict.items():
+ p2 = loaded_model_state_dict[layer_name]
+ if p1.data.ne(p2.data).sum() > 0:
+ models_equal = False
+
+ self.assertTrue(models_equal)
+
+ def test_load_image_text_config(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Save FlavaConfig and check if we can load FlavaImageConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ image_config = FlavaImageConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.image_config.to_dict(), image_config.to_dict())
+
+ # Save FlavaConfig and check if we can load FlavaTextConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ text_config = FlavaTextConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
+
+ # Save FlavaConfig and check if we can load FlavaMultimodalConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ multimodal_config = FlavaMultimodalConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.multimodal_config.to_dict(), multimodal_config.to_dict())
+
+ # overwrite from common since FlavaModel/TFFlavaModel return FLAVAOutput/TFFLAVAOutput
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FlavaModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FlavaForPreTrainingTester(FlavaModelTester):
+ model_class = FlavaForPreTraining
+
+ def prepare_config_and_inputs_for_common(self):
+ _, pixel_values, bool_masked_pos = self.image_model_tester.prepare_config_and_inputs()
+ _, input_ids, token_type_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+ config = self.get_config()
+
+ input_ids_masked = input_ids.detach().clone()
+ input_ids_masked[:, 1:3] = 100
+ mlm_labels = input_ids.detach().clone()
+ mlm_labels[:, :] = config.ce_ignore_index
+ mlm_labels[:, 1:3] = input_ids[:, 1:3]
+ mim_labels = torch.randint(
+ 0, self.image_model_tester.vocab_size, bool_masked_pos.size(), device=bool_masked_pos.device
+ ).long()
+ mim_labels[bool_masked_pos.ne(True)] = config.ce_ignore_index
+ itm_labels = torch.ones(mlm_labels.size(0), device=bool_masked_pos.device).long()
+
+ return config, {
+ "input_ids": input_ids,
+ "input_ids_masked": input_ids_masked,
+ "token_type_ids": token_type_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "bool_masked_pos": bool_masked_pos,
+ "mlm_labels": mlm_labels,
+ "mim_labels": mim_labels,
+ "itm_labels": itm_labels,
+ "return_loss": True,
+ }
+
+ def _test_model(self, config, inputs, test_image=False, test_text=False):
+ model = self.model_class(config).to(torch_device).eval()
+ with torch.no_grad():
+ result = model(
+ input_ids=inputs["input_ids"] if test_text else None,
+ input_ids_masked=inputs["input_ids_masked"] if test_text else None,
+ attention_mask=inputs["attention_mask"] if test_text else None,
+ token_type_ids=inputs["token_type_ids"] if test_text else None,
+ pixel_values=inputs["pixel_values"] if test_image else None,
+ bool_masked_pos=inputs["bool_masked_pos"] if test_image else None,
+ mlm_labels=inputs["mlm_labels"],
+ mim_labels=inputs["mim_labels"],
+ itm_labels=inputs["itm_labels"],
+ return_loss=inputs["return_loss"],
+ )
+ image_size = (self.image_model_tester.image_size, self.image_model_tester.image_size)
+ patch_size = (self.image_model_tester.patch_size, self.image_model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ if test_image:
+ self.parent.assertEqual(
+ result.image_embeddings.shape,
+ (self.image_model_tester.batch_size, num_patches + 1, self.image_model_tester.hidden_size),
+ )
+ if not test_text:
+ self.parent.assertEqual(
+ result.loss_info.mim.dim(),
+ 0,
+ )
+ self.parent.assertEqual(
+ result.mim_logits.shape,
+ (inputs["bool_masked_pos"].sum().item(), self.image_model_tester.vocab_size),
+ )
+
+ else:
+ self.parent.assertIsNone(result.image_embeddings)
+
+ if test_text:
+ self.parent.assertEqual(
+ result.text_embeddings.shape,
+ (
+ self.text_model_tester.batch_size,
+ self.text_model_tester.seq_length,
+ self.text_model_tester.hidden_size,
+ ),
+ )
+ if not test_image:
+ self.parent.assertEqual(result.loss_info.mlm.dim(), 0)
+ self.parent.assertEqual(
+ result.mlm_logits.shape,
+ (
+ (inputs["mlm_labels"] != self.multimodal_model_tester.ce_ignore_index).sum().item(),
+ self.text_model_tester.vocab_size,
+ ),
+ )
+ else:
+ self.parent.assertIsNone(result.text_embeddings)
+
+ if test_image and test_text:
+ self.parent.assertEqual(
+ result.multimodal_masked_embeddings.shape,
+ (
+ self.multimodal_model_tester.batch_size,
+ self.text_model_tester.seq_length + num_patches + 2,
+ self.multimodal_model_tester.hidden_size,
+ ),
+ )
+ self.parent.assertEqual(
+ result.itm_logits.shape,
+ (self.text_model_tester.batch_size, 2),
+ )
+ self.parent.assertEqual(
+ result.mmm_text_logits.shape,
+ (
+ (inputs["mlm_labels"] != self.multimodal_model_tester.ce_ignore_index).sum().item(),
+ self.text_model_tester.vocab_size,
+ ),
+ )
+ self.parent.assertEqual(
+ result.mmm_image_logits.shape,
+ (inputs["bool_masked_pos"].sum().item(), self.image_model_tester.vocab_size),
+ )
+ self.parent.assertEqual(
+ result.contrastive_logits_per_image.shape,
+ (self.image_model_tester.batch_size, self.text_model_tester.batch_size),
+ )
+ self.parent.assertEqual(
+ result.contrastive_logits_per_text.shape,
+ (self.text_model_tester.batch_size, self.image_model_tester.batch_size),
+ )
+
+ for item in [
+ result.loss_info.global_contrastive,
+ result.loss_info.itm,
+ result.loss_info.mmm_text,
+ result.loss_info.mmm_image,
+ ]:
+ self.parent.assertEqual(item.dim(), 0)
+
+ for item in [result.loss_info.mim, result.loss_info.mlm]:
+ self.parent.assertIsNone(item)
+
+ else:
+ self.parent.assertIsNone(result.multimodal_masked_embeddings)
+ for item in [
+ result.loss_info.global_contrastive,
+ result.loss_info.itm,
+ result.loss_info.mmm_text,
+ result.loss_info.mmm_image,
+ ]:
+ self.parent.assertIsNone(item)
+
+ self.parent.assertIsNone(result.multimodal_embeddings)
+
+
+@require_torch
+class FlavaForPreTrainingTest(FlavaModelTest):
+ all_model_classes = (FlavaForPreTraining,) if is_torch_available() else ()
+ class_for_tester = FlavaForPreTrainingTester
+ test_torchscript = False
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@require_vision
+@require_torch
+class FlavaModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+ model_name = "facebook/flava-full"
+ model = FlavaModel.from_pretrained(model_name).to(torch_device)
+ processor = FlavaProcessor.from_pretrained(model_name)
+
+ image = prepare_img()
+ inputs = processor(
+ text=["a photo of a cat", "a photo of a dog"],
+ images=[image, image],
+ padding="max_length",
+ max_length=77,
+ return_tensors="pt",
+ ).to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs, return_dict=True)
+
+ # verify the embeddings
+ self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
+ self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
+ self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -3988.51367, places=4)
+
+
+@require_vision
+@require_torch
+class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+ model_name = "facebook/flava-full"
+ model = FlavaForPreTraining.from_pretrained(model_name).to(torch_device)
+ processor = FlavaProcessor.from_pretrained(model_name)
+ torch.manual_seed(1)
+ random.seed(1)
+
+ image = prepare_img()
+ inputs = processor(
+ text=["a photo of a cat", "a photo of a dog"],
+ images=[image, image],
+ padding="max_length",
+ max_length=77,
+ return_tensors="pt",
+ return_codebook_pixels=True,
+ return_image_mask=True,
+ )
+ inputs["input_ids_masked"] = inputs["input_ids"].clone()
+ inputs["input_ids_masked"][0, 4:6] = 103
+ inputs["mlm_labels"] = inputs["input_ids"].clone()
+ inputs["mlm_labels"][:, :] = -100
+ inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
+ inputs = inputs.to(torch_device)
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ self.assertEqual(
+ outputs.contrastive_logits_per_image.shape,
+ torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
+ )
+ self.assertEqual(
+ outputs.contrastive_logits_per_text.shape,
+ torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
+ )
+
+ expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
+ self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
+ self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
+ self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
+ self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
diff --git a/tests/models/flava/test_processor_flava.py b/tests/models/flava/test_processor_flava.py
new file mode 100644
index 00000000000000..21cc84d5f299a6
--- /dev/null
+++ b/tests/models/flava/test_processor_flava.py
@@ -0,0 +1,234 @@
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import random
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+
+from transformers import BertTokenizer, BertTokenizerFast
+from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
+from transformers.testing_utils import require_vision
+from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FlavaFeatureExtractor, FlavaProcessor
+ from transformers.models.flava.feature_extraction_flava import (
+ FLAVA_CODEBOOK_MEAN,
+ FLAVA_CODEBOOK_STD,
+ FLAVA_IMAGE_MEAN,
+ FLAVA_IMAGE_STD,
+ )
+
+
+@require_vision
+class FlavaProcessorTest(unittest.TestCase):
+ def setUp(self):
+ self.tmpdirname = tempfile.mkdtemp()
+
+ # fmt: off
+ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"]
+ # fmt: on
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write("".join([x + "\n" for x in vocab_tokens]))
+
+ feature_extractor_map = {
+ "image_mean": FLAVA_IMAGE_MEAN,
+ "image_std": FLAVA_IMAGE_STD,
+ "do_normalize": True,
+ "do_resize": True,
+ "size": 224,
+ "do_center_crop": True,
+ "crop_size": 224,
+ "input_size_patches": 14,
+ "total_mask_patches": 75,
+ "mask_group_max_patches": None,
+ "mask_group_min_patches": 16,
+ "mask_group_min_aspect_ratio": 0.3,
+ "mask_group_max_aspect_ratio": None,
+ "codebook_do_resize": True,
+ "codebook_size": 112,
+ "codebook_resample": None,
+ "codebook_do_center_crop": True,
+ "codebook_crop_size": 112,
+ "codebook_do_map_pixels": True,
+ "codebook_do_normalize": True,
+ "codebook_image_mean": FLAVA_CODEBOOK_MEAN,
+ "codebook_image_std": FLAVA_CODEBOOK_STD,
+ }
+
+ self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.feature_extractor_file, "w", encoding="utf-8") as fp:
+ json.dump(feature_extractor_map, fp)
+
+ def get_tokenizer(self, **kwargs):
+ return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs):
+ return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return FlavaFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def prepare_image_inputs(self):
+ """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+ or a list of PyTorch tensors if one specifies torchify=True.
+ """
+
+ image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
+
+ image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+
+ return image_inputs
+
+ def test_save_load_pretrained_default(self):
+ tokenizer_slow = self.get_tokenizer()
+ tokenizer_fast = self.get_rust_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ processor_slow = FlavaProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor)
+ processor_slow.save_pretrained(self.tmpdirname)
+ processor_slow = FlavaProcessor.from_pretrained(self.tmpdirname, use_fast=False)
+
+ processor_fast = FlavaProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor)
+ processor_fast.save_pretrained(self.tmpdirname)
+ processor_fast = FlavaProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
+ self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
+ self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
+ self.assertIsInstance(processor_slow.tokenizer, BertTokenizer)
+ self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast)
+
+ self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor_slow.feature_extractor, FlavaFeatureExtractor)
+ self.assertIsInstance(processor_fast.feature_extractor, FlavaFeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = FlavaProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
+
+ processor = FlavaProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, BertTokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, FlavaFeatureExtractor)
+
+ def test_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ image_input = self.prepare_image_inputs()
+
+ input_feat_extract = feature_extractor(image_input, return_tensors="np")
+ input_processor = processor(images=image_input, return_tensors="np")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ # With rest of the args
+ random.seed(1234)
+ input_feat_extract = feature_extractor(
+ image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
+ )
+ random.seed(1234)
+ input_processor = processor(
+ images=image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
+ )
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "lower newer"
+
+ encoded_processor = processor(text=input_str)
+
+ encoded_tok = tokenizer(input_str)
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key], encoded_processor[key])
+
+ def test_processor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "lower newer"
+ image_input = self.prepare_image_inputs()
+
+ inputs = processor(text=input_str, images=image_input)
+
+ self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])
+
+ # add extra args
+ inputs = processor(text=input_str, images=image_input, return_codebook_pixels=True, return_image_mask=True)
+
+ self.assertListEqual(
+ list(inputs.keys()),
+ [
+ "input_ids",
+ "token_type_ids",
+ "attention_mask",
+ "pixel_values",
+ "codebook_pixel_values",
+ "bool_masked_pos",
+ ],
+ )
+
+ # test if it raises when no input is passed
+ with pytest.raises(ValueError):
+ processor()
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)
diff --git a/tests/gpt_neo/__init__.py b/tests/models/fnet/__init__.py
similarity index 100%
rename from tests/gpt_neo/__init__.py
rename to tests/models/fnet/__init__.py
diff --git a/tests/fnet/test_modeling_fnet.py b/tests/models/fnet/test_modeling_fnet.py
similarity index 97%
rename from tests/fnet/test_modeling_fnet.py
rename to tests/models/fnet/test_modeling_fnet.py
index b175bf3a540364..974d7c2d4e5d63 100644
--- a/tests/fnet/test_modeling_fnet.py
+++ b/tests/models/fnet/test_modeling_fnet.py
@@ -22,8 +22,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -333,7 +333,12 @@ def recursive_check(tuple_object, dict_object):
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
- msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
)
recursive_check(tuple_output, dict_output)
diff --git a/tests/fnet/test_tokenization_fnet.py b/tests/models/fnet/test_tokenization_fnet.py
similarity index 98%
rename from tests/fnet/test_tokenization_fnet.py
rename to tests/models/fnet/test_tokenization_fnet.py
index a620ccf1f3b1be..0058155bdb6d3e 100644
--- a/tests/fnet/test_tokenization_fnet.py
+++ b/tests/models/fnet/test_tokenization_fnet.py
@@ -13,17 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
from transformers import FNetTokenizer, FNetTokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow, tooslow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow, tooslow
from transformers.tokenization_utils import AddedToken
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/spiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
@require_sentencepiece
diff --git a/tests/gptj/__init__.py b/tests/models/fsmt/__init__.py
similarity index 100%
rename from tests/gptj/__init__.py
rename to tests/models/fsmt/__init__.py
diff --git a/tests/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py
similarity index 99%
rename from tests/fsmt/test_modeling_fsmt.py
rename to tests/models/fsmt/test_modeling_fsmt.py
index c452a5729b100d..9e487b609aae09 100644
--- a/tests/fsmt/test_modeling_fsmt.py
+++ b/tests/models/fsmt/test_modeling_fsmt.py
@@ -23,9 +23,9 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
diff --git a/tests/fsmt/test_tokenization_fsmt.py b/tests/models/fsmt/test_tokenization_fsmt.py
similarity index 99%
rename from tests/fsmt/test_tokenization_fsmt.py
rename to tests/models/fsmt/test_tokenization_fsmt.py
index f5e3d4b6cf7572..7407c2fbc86795 100644
--- a/tests/fsmt/test_tokenization_fsmt.py
+++ b/tests/models/fsmt/test_tokenization_fsmt.py
@@ -22,7 +22,7 @@
from transformers.testing_utils import slow
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
# using a different tiny model than the one used for default params defined in init to ensure proper testing
diff --git a/tests/herbert/__init__.py b/tests/models/funnel/__init__.py
similarity index 100%
rename from tests/herbert/__init__.py
rename to tests/models/funnel/__init__.py
diff --git a/tests/funnel/test_modeling_funnel.py b/tests/models/funnel/test_modeling_funnel.py
similarity index 99%
rename from tests/funnel/test_modeling_funnel.py
rename to tests/models/funnel/test_modeling_funnel.py
index 73f5ec4b1778e0..c0520203a97f8a 100644
--- a/tests/funnel/test_modeling_funnel.py
+++ b/tests/models/funnel/test_modeling_funnel.py
@@ -20,8 +20,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
diff --git a/tests/funnel/test_modeling_tf_funnel.py b/tests/models/funnel/test_modeling_tf_funnel.py
similarity index 99%
rename from tests/funnel/test_modeling_tf_funnel.py
rename to tests/models/funnel/test_modeling_tf_funnel.py
index c3ae3788d61e4e..422985f7a6fb2f 100644
--- a/tests/funnel/test_modeling_tf_funnel.py
+++ b/tests/models/funnel/test_modeling_tf_funnel.py
@@ -19,8 +19,8 @@
from transformers import FunnelConfig, is_tf_available
from transformers.testing_utils import require_tf
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/funnel/test_tokenization_funnel.py b/tests/models/funnel/test_tokenization_funnel.py
similarity index 97%
rename from tests/funnel/test_tokenization_funnel.py
rename to tests/models/funnel/test_tokenization_funnel.py
index 592f19b4114128..e46928a538fdf9 100644
--- a/tests/funnel/test_tokenization_funnel.py
+++ b/tests/models/funnel/test_tokenization_funnel.py
@@ -21,7 +21,7 @@
from transformers.models.funnel.tokenization_funnel import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/hubert/__init__.py b/tests/models/glpn/__init__.py
similarity index 100%
rename from tests/hubert/__init__.py
rename to tests/models/glpn/__init__.py
diff --git a/tests/glpn/test_feature_extraction_glpn.py b/tests/models/glpn/test_feature_extraction_glpn.py
similarity index 97%
rename from tests/glpn/test_feature_extraction_glpn.py
rename to tests/models/glpn/test_feature_extraction_glpn.py
index c903491ce103aa..4e7f2bdf5c7834 100644
--- a/tests/glpn/test_feature_extraction_glpn.py
+++ b/tests/models/glpn/test_feature_extraction_glpn.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/glpn/test_modeling_glpn.py b/tests/models/glpn/test_modeling_glpn.py
similarity index 99%
rename from tests/glpn/test_modeling_glpn.py
rename to tests/models/glpn/test_modeling_glpn.py
index 323215d78b02fb..7d34a7f4f30dbd 100644
--- a/tests/glpn/test_modeling_glpn.py
+++ b/tests/models/glpn/test_modeling_glpn.py
@@ -22,8 +22,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
diff --git a/tests/ibert/__init__.py b/tests/models/gpt2/__init__.py
similarity index 100%
rename from tests/ibert/__init__.py
rename to tests/models/gpt2/__init__.py
diff --git a/tests/gpt2/test_modeling_flax_gpt2.py b/tests/models/gpt2/test_modeling_flax_gpt2.py
similarity index 98%
rename from tests/gpt2/test_modeling_flax_gpt2.py
rename to tests/models/gpt2/test_modeling_flax_gpt2.py
index 7be52b5a1151d2..a86377e42f7c27 100644
--- a/tests/gpt2/test_modeling_flax_gpt2.py
+++ b/tests/models/gpt2/test_modeling_flax_gpt2.py
@@ -22,8 +22,8 @@
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
diff --git a/tests/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py
similarity index 98%
rename from tests/gpt2/test_modeling_gpt2.py
rename to tests/models/gpt2/test_modeling_gpt2.py
index cea36400b2b872..0960daff836034 100644
--- a/tests/gpt2/test_modeling_gpt2.py
+++ b/tests/models/gpt2/test_modeling_gpt2.py
@@ -21,9 +21,9 @@
from transformers import GPT2Config, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -166,6 +166,11 @@ def get_config(
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py
similarity index 87%
rename from tests/gpt2/test_modeling_tf_gpt2.py
rename to tests/models/gpt2/test_modeling_tf_gpt2.py
index 2092fb8feb37a9..93b48ce8f29948 100644
--- a/tests/gpt2/test_modeling_tf_gpt2.py
+++ b/tests/models/gpt2/test_modeling_tf_gpt2.py
@@ -18,9 +18,9 @@
from transformers import GPT2Config, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
-from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
if is_tf_available():
@@ -456,7 +456,7 @@ def test_lm_generate_greedy_distilgpt2_batch_special(self):
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
- input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
generation_kwargs = {
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
@@ -465,12 +465,12 @@ def test_lm_generate_greedy_distilgpt2_batch_special(self):
"repetition_penalty": 1.3,
}
- output_ids = model.generate(input_ids, **generation_kwargs)
+ output_ids = model.generate(**input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and I am so happy to be able take part in this amazing event.",
- "Yesterday was a very busy day for the first time since I started writing this post",
+ "Yesterday was a very interesting time for the world to see how much of this is",
]
self.assertListEqual(output_strings, expected_output_string)
@@ -483,7 +483,7 @@ def test_lm_generate_sample_distilgpt2_batch_special(self):
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
- input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
generation_kwargs = {
"do_sample": True,
@@ -498,13 +498,13 @@ def test_lm_generate_sample_distilgpt2_batch_special(self):
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
- output_ids = model.generate(input_ids, **generation_kwargs)
+ output_ids = model.generate(**input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
- "Today is a beautiful day and we will make you feel very hot/terrific in all",
- "Yesterday was another solid success as news coverage became standard American domestic television hit.",
+ "Today is a beautiful day and we will make you feel very hot/terrific in all your",
+ "Yesterday was known by national television networks as Le Big Show or Wild Dog Jeopard",
]
self.assertListEqual(output_strings, expected_output_string)
@@ -517,7 +517,7 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
- input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
generation_kwargs = {
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
@@ -526,37 +526,69 @@ def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
"num_beams": 2,
}
- output_ids = model.generate(input_ids, **generation_kwargs)
+ output_ids = model.generate(**input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and a great day for all of us.\n\nIām",
- "Yesterday was the first day of the year for the second time in a row,",
+ "Yesterday was the first time that a person has been arrested in the United States for",
]
self.assertListEqual(output_strings, expected_output_string)
+ @slow
+ def test_lm_generate_distilgpt2_left_padding(self):
+ """Tests that the generated text is the same, regarless of left padding"""
+ model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
+ tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
+
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"
+
+ generation_kwargs = {
+ "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
+ "no_repeat_ngram_size": 2,
+ "do_sample": False,
+ "repetition_penalty": 1.3,
+ }
+ expected_output_string = (
+ "Today is a beautiful day and I am so happy to be able take part in this amazing event."
+ )
+
+ sentences = ["Today is a beautiful day and"]
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+ # using default length
+ output_ids = model.generate(**input_ids, **generation_kwargs)
+ output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertEqual(output_strings[0], expected_output_string)
+
+ sentences = ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"]
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+ # longer max length to capture the full length (remember: it is left padded)
+ output_ids = model.generate(**input_ids, **generation_kwargs, max_length=27)
+ output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertEqual(output_strings[0], expected_output_string)
+
@slow
def test_lm_generate_gpt2_greedy_xla(self):
- # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
- # the underlying problem)
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
- sentences = ["The dog"]
+ sentences = ["The dog", "The flying machine"]
expected_output_strings = [
- "The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
+ "The dog was found in a field near the intersection of West and West Streets.\n\nThe",
+ "The flying machine is a small, lightweight, and lightweight aircraft that can be used for any type of",
]
- input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
- output_ids = model.generate(input_ids, do_sample=False)
+ output_ids = model.generate(**input_ids, do_sample=False)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_strings)
xla_generate = tf.function(model.generate, jit_compile=True)
- output_ids = xla_generate(input_ids, do_sample=False)
+ output_ids = xla_generate(**input_ids, do_sample=False)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_strings)
@@ -574,20 +606,24 @@ def test_lm_generate_gpt2_sample_xla(self):
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
- sentence = ["The dog"]
+ sentence = ["The dog", "The flying machine"]
expected_output_string = [
- "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most puppies"
+ "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most"
+ " puppies",
+ "The flying machine was made by an artist who found it difficult to control it as it did not use",
]
expected_output_string_xla = [
- "The dog has been named in connection with the murder of a 20-year-old man in!"
+ "The dog has been named in connection with the murder of a 20-year-old man in",
+ "The flying machine is a new and improved system to operate and operate a new system and system "
+ "system system",
]
- input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+ input_ids = tokenizer(sentence, return_tensors="tf", padding=True)
- output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
+ output_ids = model.generate(**input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string)
xla_generate = tf.function(model.generate, jit_compile=True)
- output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
+ output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string_xla)
diff --git a/tests/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py
similarity index 69%
rename from tests/gpt2/test_tokenization_gpt2.py
rename to tests/models/gpt2/test_tokenization_gpt2.py
index 96f18e166c14ae..d76bc75ccbd582 100644
--- a/tests/gpt2/test_tokenization_gpt2.py
+++ b/tests/models/gpt2/test_tokenization_gpt2.py
@@ -22,7 +22,7 @@
from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
@@ -175,6 +175,78 @@ def test_padding(self, max_length=15):
padding="max_length",
)
+ def test_padding_if_pad_token_set_slow(self):
+ tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, pad_token="")
+
+ # Simple input
+ s = "This is a simple input"
+ s2 = ["This is a simple input looooooooong", "This is a simple input"]
+ p = ("This is a simple input", "This is a pair")
+ p2 = [
+ ("This is a simple input loooooong", "This is a simple input"),
+ ("This is a simple pair loooooong", "This is a simple pair"),
+ ]
+
+ pad_token_id = tokenizer.pad_token_id
+
+ out_s = tokenizer(s, padding="max_length", max_length=30, return_tensors="np")
+ out_s2 = tokenizer(s2, padding=True, truncate=True, return_tensors="np")
+ out_p = tokenizer(*p, padding="max_length", max_length=60, return_tensors="np")
+ out_p2 = tokenizer(p2, padding=True, truncate=True, return_tensors="np")
+
+ # s
+ # test single string max_length padding
+ self.assertEqual(out_s["input_ids"].shape[-1], 30)
+ self.assertTrue(pad_token_id in out_s["input_ids"])
+ self.assertTrue(0 in out_s["attention_mask"])
+
+ # s2
+ # test automatic padding
+ self.assertEqual(out_s2["input_ids"].shape[-1], 33)
+ # long slice doesn't have padding
+ self.assertFalse(pad_token_id in out_s2["input_ids"][0])
+ self.assertFalse(0 in out_s2["attention_mask"][0])
+ # short slice does have padding
+ self.assertTrue(pad_token_id in out_s2["input_ids"][1])
+ self.assertTrue(0 in out_s2["attention_mask"][1])
+
+ # p
+ # test single pair max_length padding
+ self.assertEqual(out_p["input_ids"].shape[-1], 60)
+ self.assertTrue(pad_token_id in out_p["input_ids"])
+ self.assertTrue(0 in out_p["attention_mask"])
+
+ # p2
+ # test automatic padding pair
+ self.assertEqual(out_p2["input_ids"].shape[-1], 52)
+ # long slice pair doesn't have padding
+ self.assertFalse(pad_token_id in out_p2["input_ids"][0])
+ self.assertFalse(0 in out_p2["attention_mask"][0])
+ # short slice pair does have padding
+ self.assertTrue(pad_token_id in out_p2["input_ids"][1])
+ self.assertTrue(0 in out_p2["attention_mask"][1])
+
+ def test_add_bos_token_slow(self):
+ bos_token = "$$$"
+ tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, bos_token=bos_token, add_bos_token=True)
+
+ s = "This is a simple input"
+ s2 = ["This is a simple input 1", "This is a simple input 2"]
+
+ bos_token_id = tokenizer.bos_token_id
+
+ out_s = tokenizer(s)
+ out_s2 = tokenizer(s2)
+
+ self.assertEqual(out_s.input_ids[0], bos_token_id)
+ self.assertTrue(all(o[0] == bos_token_id for o in out_s2.input_ids))
+
+ decode_s = tokenizer.decode(out_s.input_ids)
+ decode_s2 = tokenizer.batch_decode(out_s2.input_ids)
+
+ self.assertEqual(decode_s.split()[0], bos_token)
+ self.assertTrue(all(d.split()[0] == bos_token for d in decode_s2))
+
# tokenizer has no padding token
def test_padding_different_model_input_name(self):
pass
diff --git a/tests/imagegpt/__init__.py b/tests/models/gpt_neo/__init__.py
similarity index 100%
rename from tests/imagegpt/__init__.py
rename to tests/models/gpt_neo/__init__.py
diff --git a/tests/gpt_neo/test_modeling_flax_gpt_neo.py b/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py
similarity index 98%
rename from tests/gpt_neo/test_modeling_flax_gpt_neo.py
rename to tests/models/gpt_neo/test_modeling_flax_gpt_neo.py
index 580138a7b31e8b..74659c56a8e4ea 100644
--- a/tests/gpt_neo/test_modeling_flax_gpt_neo.py
+++ b/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py
@@ -22,8 +22,8 @@
from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
diff --git a/tests/gpt_neo/test_modeling_gpt_neo.py b/tests/models/gpt_neo/test_modeling_gpt_neo.py
similarity index 98%
rename from tests/gpt_neo/test_modeling_gpt_neo.py
rename to tests/models/gpt_neo/test_modeling_gpt_neo.py
index 4135d7857cd1c9..16a775e2731b4e 100644
--- a/tests/gpt_neo/test_modeling_gpt_neo.py
+++ b/tests/models/gpt_neo/test_modeling_gpt_neo.py
@@ -21,9 +21,9 @@
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -151,6 +151,11 @@ def get_config(self):
attention_types=self.attention_types,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/layoutlm/__init__.py b/tests/models/gpt_neox/__init__.py
similarity index 100%
rename from tests/layoutlm/__init__.py
rename to tests/models/gpt_neox/__init__.py
diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py
new file mode 100644
index 00000000000000..a4fb95384e83ff
--- /dev/null
+++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -0,0 +1,245 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch GPTNeoX model. """
+
+
+import unittest
+
+from transformers import GPTNeoXConfig, is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import GPTNeoXForCausalLM, GPTNeoXModel
+ from transformers.models.gpt_neox.modeling_gpt_neox import GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+class GPTNeoXModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ token_labels = None
+ if self.use_labels:
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+
+ config = self.get_config()
+
+ return config, input_ids, input_mask, token_labels
+
+ def get_config(self):
+ return GPTNeoXConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
+
+ config.is_decoder = True
+
+ return config, input_ids, input_mask, token_labels
+
+ def create_and_check_model(self, config, input_ids, input_mask):
+ model = GPTNeoXModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ _ = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_model_as_decoder(self, config, input_ids, input_mask):
+ config.add_cross_attention = True
+ model = GPTNeoXModel(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_labels):
+ model = GPTNeoXForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
+ config.is_decoder = True
+ model = GPTNeoXForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
+ past_key_values = outputs.past_key_values
+
+ # create hypothetical multiple next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append to next input_ids and
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True)
+ output_from_no_past = output_from_no_past["hidden_states"][0]
+ output_from_past = model(
+ next_tokens,
+ attention_mask=next_attention_mask,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ )["hidden_states"][0]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, input_mask, token_labels = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
+ all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
+ test_pruning = False
+ test_missing_keys = False
+ test_model_parallel = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = GPTNeoXModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(config, input_ids, input_mask)
+
+ def test_model_as_decoder(self):
+ config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
+
+ def test_model_as_decoder_with_default_input_mask(self):
+ # This regression test was failing with PyTorch < 1.3
+ config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder()
+
+ input_mask = None
+
+ self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = GPTNeoXModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class GPTNeoXModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference_masked_lm(self):
+ model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b")
+ input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
+ output = model(input_ids)[0]
+
+ vocab_size = model.config.vocab_size
+
+ expected_shape = torch.Size((1, 6, vocab_size))
+ self.assertEqual(output.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[[33.8045, 2.3958, 34.2816], [63.7805, 4.8332, 63.5882], [66.9116, 5.2198, 63.1185]]]
+ )
+
+ self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
diff --git a/tests/layoutlmv2/__init__.py b/tests/models/gptj/__init__.py
similarity index 100%
rename from tests/layoutlmv2/__init__.py
rename to tests/models/gptj/__init__.py
diff --git a/tests/gptj/test_modeling_flax_gptj.py b/tests/models/gptj/test_modeling_flax_gptj.py
similarity index 98%
rename from tests/gptj/test_modeling_flax_gptj.py
rename to tests/models/gptj/test_modeling_flax_gptj.py
index 3a6d71b7bfb103..0b98ed5670d37a 100644
--- a/tests/gptj/test_modeling_flax_gptj.py
+++ b/tests/models/gptj/test_modeling_flax_gptj.py
@@ -22,8 +22,8 @@
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, tooslow
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
diff --git a/tests/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py
similarity index 98%
rename from tests/gptj/test_modeling_gptj.py
rename to tests/models/gptj/test_modeling_gptj.py
index 0cabb2342b12b2..b8b088d42f1e27 100644
--- a/tests/gptj/test_modeling_gptj.py
+++ b/tests/models/gptj/test_modeling_gptj.py
@@ -20,9 +20,9 @@
from transformers import GPTJConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, tooslow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -155,6 +155,11 @@ def get_config(self):
rotary_dim=self.rotary_dim,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/gptj/test_modeling_tf_gptj.py b/tests/models/gptj/test_modeling_tf_gptj.py
similarity index 99%
rename from tests/gptj/test_modeling_tf_gptj.py
rename to tests/models/gptj/test_modeling_tf_gptj.py
index 63feffb8c62eca..0d9af0b65087a3 100644
--- a/tests/gptj/test_modeling_tf_gptj.py
+++ b/tests/models/gptj/test_modeling_tf_gptj.py
@@ -19,9 +19,9 @@
from transformers import AutoTokenizer, GPTJConfig, is_tf_available
from transformers.testing_utils import require_tf, slow, tooslow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
-from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
if is_tf_available():
diff --git a/tests/layoutxlm/__init__.py b/tests/models/herbert/__init__.py
similarity index 100%
rename from tests/layoutxlm/__init__.py
rename to tests/models/herbert/__init__.py
diff --git a/tests/herbert/test_tokenization_herbert.py b/tests/models/herbert/test_tokenization_herbert.py
similarity index 98%
rename from tests/herbert/test_tokenization_herbert.py
rename to tests/models/herbert/test_tokenization_herbert.py
index d4a30e241d7c00..3e8d3ac6ea2993 100644
--- a/tests/herbert/test_tokenization_herbert.py
+++ b/tests/models/herbert/test_tokenization_herbert.py
@@ -22,7 +22,7 @@
from transformers.models.herbert.tokenization_herbert import VOCAB_FILES_NAMES
from transformers.testing_utils import get_tests_dir, require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/led/__init__.py b/tests/models/hubert/__init__.py
similarity index 100%
rename from tests/led/__init__.py
rename to tests/models/hubert/__init__.py
diff --git a/tests/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py
similarity index 86%
rename from tests/hubert/test_modeling_hubert.py
rename to tests/models/hubert/test_modeling_hubert.py
index 0bc854114de4a5..1e27690bd47a71 100644
--- a/tests/hubert/test_modeling_hubert.py
+++ b/tests/models/hubert/test_modeling_hubert.py
@@ -16,15 +16,19 @@
import math
+import os
+import pickle
+import tempfile
import unittest
import pytest
from transformers import HubertConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
+from transformers.utils import is_torch_fx_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -45,6 +49,9 @@
)
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
+
class HubertModelTester:
def __init__(
@@ -106,7 +113,7 @@ def __init__(
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
@@ -299,6 +306,7 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
test_headmasking = False
@@ -417,6 +425,117 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
+ # Hubert cannot be TorchScripted because of torch.nn.utils.weight_norm
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = [
+ "attention_mask",
+ "decoder_attention_mask",
+ "decoder_input_ids",
+ "input_features",
+ "input_ids",
+ "input_values",
+ ]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = [
+ "attention_mask",
+ "bbox",
+ "input_features",
+ "input_ids",
+ "input_values",
+ "pixel_values",
+ "token_type_ids",
+ "visual_feats",
+ "visual_pos",
+ ]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except Exception as e:
+ self.fail(f"Couldn't trace module: {e}")
+
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
diff --git a/tests/hubert/test_modeling_tf_hubert.py b/tests/models/hubert/test_modeling_tf_hubert.py
similarity index 98%
rename from tests/hubert/test_modeling_tf_hubert.py
rename to tests/models/hubert/test_modeling_tf_hubert.py
index 5331395b89ca63..871d466d97129b 100644
--- a/tests/hubert/test_modeling_tf_hubert.py
+++ b/tests/models/hubert/test_modeling_tf_hubert.py
@@ -25,8 +25,8 @@
from transformers import is_tf_available
from transformers.testing_utils import require_soundfile, require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
@@ -539,7 +539,8 @@ def test_inference_ctc_robust_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with the thousands of spectators were trivialities not worth thinking about",
"his instant of panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/longformer/__init__.py b/tests/models/ibert/__init__.py
similarity index 100%
rename from tests/longformer/__init__.py
rename to tests/models/ibert/__init__.py
diff --git a/tests/ibert/test_modeling_ibert.py b/tests/models/ibert/test_modeling_ibert.py
old mode 100755
new mode 100644
similarity index 99%
rename from tests/ibert/test_modeling_ibert.py
rename to tests/models/ibert/test_modeling_ibert.py
index 41819d973be630..78ba4d4604d1a4
--- a/tests/ibert/test_modeling_ibert.py
+++ b/tests/models/ibert/test_modeling_ibert.py
@@ -20,8 +20,8 @@
from transformers import IBertConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
@@ -116,6 +116,11 @@ def get_config(self):
quant_mode=True,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
diff --git a/tests/luke/__init__.py b/tests/models/imagegpt/__init__.py
similarity index 100%
rename from tests/luke/__init__.py
rename to tests/models/imagegpt/__init__.py
diff --git a/tests/imagegpt/test_feature_extraction_imagegpt.py b/tests/models/imagegpt/test_feature_extraction_imagegpt.py
similarity index 98%
rename from tests/imagegpt/test_feature_extraction_imagegpt.py
rename to tests/models/imagegpt/test_feature_extraction_imagegpt.py
index 4d1ca087d80b93..1dd3786759fd6c 100644
--- a/tests/imagegpt/test_feature_extraction_imagegpt.py
+++ b/tests/models/imagegpt/test_feature_extraction_imagegpt.py
@@ -25,7 +25,7 @@
from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin
if is_torch_available():
diff --git a/tests/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py
similarity index 98%
rename from tests/imagegpt/test_modeling_imagegpt.py
rename to tests/models/imagegpt/test_modeling_imagegpt.py
index c570e6192779b7..528532d4cd813d 100644
--- a/tests/imagegpt/test_modeling_imagegpt.py
+++ b/tests/models/imagegpt/test_modeling_imagegpt.py
@@ -24,9 +24,9 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -171,6 +171,12 @@ def get_config(
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 513
+ config.max_position_embeddings = 1024
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/lxmert/__init__.py b/tests/models/layoutlm/__init__.py
similarity index 100%
rename from tests/lxmert/__init__.py
rename to tests/models/layoutlm/__init__.py
diff --git a/tests/layoutlm/test_modeling_layoutlm.py b/tests/models/layoutlm/test_modeling_layoutlm.py
similarity index 99%
rename from tests/layoutlm/test_modeling_layoutlm.py
rename to tests/models/layoutlm/test_modeling_layoutlm.py
index faf4458cc81ee8..e2d949611d78e8 100644
--- a/tests/layoutlm/test_modeling_layoutlm.py
+++ b/tests/models/layoutlm/test_modeling_layoutlm.py
@@ -19,8 +19,8 @@
from transformers import LayoutLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else None
)
+ fx_compatible = True
def setUp(self):
self.model_tester = LayoutLMModelTester(self)
diff --git a/tests/layoutlm/test_modeling_tf_layoutlm.py b/tests/models/layoutlm/test_modeling_tf_layoutlm.py
similarity index 98%
rename from tests/layoutlm/test_modeling_tf_layoutlm.py
rename to tests/models/layoutlm/test_modeling_tf_layoutlm.py
index 90e2b4fcf16973..fb230aab56e820 100644
--- a/tests/layoutlm/test_modeling_tf_layoutlm.py
+++ b/tests/models/layoutlm/test_modeling_tf_layoutlm.py
@@ -20,8 +20,8 @@
from transformers import LayoutLMConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/layoutlm/test_tokenization_layoutlm.py b/tests/models/layoutlm/test_tokenization_layoutlm.py
similarity index 97%
rename from tests/layoutlm/test_tokenization_layoutlm.py
rename to tests/models/layoutlm/test_tokenization_layoutlm.py
index dab51216586bac..3663355ee50717 100644
--- a/tests/layoutlm/test_tokenization_layoutlm.py
+++ b/tests/models/layoutlm/test_tokenization_layoutlm.py
@@ -21,7 +21,7 @@
from transformers.models.layoutlm.tokenization_layoutlm import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/m2m_100/__init__.py b/tests/models/layoutlmv2/__init__.py
similarity index 100%
rename from tests/m2m_100/__init__.py
rename to tests/models/layoutlmv2/__init__.py
diff --git a/tests/layoutlmv2/test_feature_extraction_layoutlmv2.py b/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py
similarity index 99%
rename from tests/layoutlmv2/test_feature_extraction_layoutlmv2.py
rename to tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py
index ca3bbf1cdc9d16..59c30d779c5f57 100644
--- a/tests/layoutlmv2/test_feature_extraction_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import require_pytesseract, require_torch
from transformers.utils import is_pytesseract_available, is_torch_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/layoutlmv2/test_modeling_layoutlmv2.py b/tests/models/layoutlmv2/test_modeling_layoutlmv2.py
similarity index 99%
rename from tests/layoutlmv2/test_modeling_layoutlmv2.py
rename to tests/models/layoutlmv2/test_modeling_layoutlmv2.py
index 708a433989ae6e..bfcd729df15335 100644
--- a/tests/layoutlmv2/test_modeling_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_modeling_layoutlmv2.py
@@ -23,8 +23,8 @@
from transformers.testing_utils import require_detectron2, require_torch, slow, torch_device
from transformers.utils import is_detectron2_available, is_torch_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/layoutlmv2/test_processor_layoutlmv2.py b/tests/models/layoutlmv2/test_processor_layoutlmv2.py
similarity index 90%
rename from tests/layoutlmv2/test_processor_layoutlmv2.py
rename to tests/models/layoutlmv2/test_processor_layoutlmv2.py
index e822d177ca6613..4f686155adc715 100644
--- a/tests/layoutlmv2/test_processor_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_processor_layoutlmv2.py
@@ -133,6 +133,39 @@ def test_save_load_pretrained_additional_features(self):
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
+ @slow
+ def test_overflowing_tokens(self):
+ # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
+
+ from datasets import load_dataset
+
+ # set up
+ datasets = load_dataset("nielsr/funsd")
+ processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
+
+ def preprocess_data(examples):
+ images = [Image.open(path).convert("RGB") for path in examples["image_path"]]
+ words = examples["words"]
+ boxes = examples["bboxes"]
+ word_labels = examples["ner_tags"]
+ encoded_inputs = processor(
+ images,
+ words,
+ boxes=boxes,
+ word_labels=word_labels,
+ padding="max_length",
+ truncation=True,
+ return_overflowing_tokens=True,
+ stride=50,
+ return_offsets_mapping=True,
+ return_tensors="pt",
+ )
+ return encoded_inputs
+
+ train_data = preprocess_data(datasets["train"])
+
+ self.assertEqual(len(train_data["image"]), len(train_data["input_ids"]))
+
# different use cases tests
@require_torch
@@ -182,10 +215,11 @@ def test_processor_case_1(self):
)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = "[CLS] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president ā introductory remarks ā lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -203,10 +237,11 @@ def test_processor_case_1(self):
)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = "[CLS] 7 itc limited report and accounts 2013 itc ā s brands : an asset for the nation the consumer needs and aspirations they fulfil, the benefit they generate for millions across itc ā s value chains, the future - ready capabilities that support them, and the value that they create for the country, have made itc ā s brands national assets, adding to india ā s competitiveness. it is itc ā s aspiration to be the no 1 fmcg player in the country, driven by its new fmcg businesses. a recent nielsen report has highlighted that itc's new fmcg businesses are the fastest growing among the top consumer goods companies operating in india. itc takes justifiable pride that, along with generating economic value, these celebrated indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. di wills * ; love delightfully soft skin? aia ans source : https : / / www. industrydocuments. ucsf. edu / docs / snbx0223 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
@slow
@@ -233,7 +268,7 @@ def test_processor_case_2(self):
# verify input_ids
expected_decoding = "[CLS] hello world [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -248,7 +283,7 @@ def test_processor_case_2(self):
# verify input_ids
expected_decoding = "[CLS] hello world [SEP] [PAD] [PAD] [PAD]"
- decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -287,7 +322,7 @@ def test_processor_case_3(self):
# verify input_ids
expected_decoding = "[CLS] weirdly world [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify labels
@@ -309,7 +344,7 @@ def test_processor_case_3(self):
# verify input_ids
expected_decoding = "[CLS] my name is niels [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -349,10 +384,11 @@ def test_processor_case_4(self):
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = "[CLS] what's his name? [SEP] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president ā introductory remarks ā lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -367,8 +403,9 @@ def test_processor_case_4(self):
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
expected_decoding = "[CLS] what's the time [SEP] 7 itc limited report and accounts 2013 itc ā s [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -401,7 +438,7 @@ def test_processor_case_5(self):
# verify input_ids
expected_decoding = "[CLS] what's his name? [SEP] hello world [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -417,11 +454,11 @@ def test_processor_case_5(self):
# verify input_ids
expected_decoding = "[CLS] how old is he? [SEP] hello world [SEP] [PAD] [PAD] [PAD]"
- decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
expected_decoding = "[CLS] what's the time [SEP] my name is niels [SEP]"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
diff --git a/tests/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
similarity index 99%
rename from tests/layoutlmv2/test_tokenization_layoutlmv2.py
rename to tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
index 249660d4a3f2d4..78f78c33e7f7f1 100644
--- a/tests/layoutlmv2/test_tokenization_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -33,7 +33,7 @@
)
from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, require_tokenizers, require_torch, slow
-from ..test_tokenization_common import (
+from ...test_tokenization_common import (
SMALL_TRAINING_CORPUS,
TokenizerTesterMixin,
filter_non_english,
@@ -181,7 +181,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
@@ -1634,11 +1634,9 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
break
self.assertTrue(
find,
- (
- f"'{new_special_token_str}' doesn't appear in the list "
- f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
- f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}"
- ),
+ f"'{new_special_token_str}' doesn't appear in the list "
+ f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
+ f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}",
)
elif special_token not in special_tokens_map:
# The special token must appear identically in the list of the new tokenizer.
@@ -1923,7 +1921,8 @@ def test_maximum_encoding_length_pair_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -1937,7 +1936,8 @@ def test_maximum_encoding_length_pair_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
# Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation
@@ -2232,7 +2232,8 @@ def test_maximum_encoding_length_single_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -2244,7 +2245,8 @@ def test_maximum_encoding_length_single_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
# Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation
diff --git a/tests/marian/__init__.py b/tests/models/layoutlmv3/__init__.py
similarity index 100%
rename from tests/marian/__init__.py
rename to tests/models/layoutlmv3/__init__.py
diff --git a/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py
new file mode 100644
index 00000000000000..9d05a4b6658efc
--- /dev/null
+++ b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py
@@ -0,0 +1,213 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_pytesseract, require_torch
+from transformers.utils import is_pytesseract_available, is_torch_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_pytesseract_available():
+ from PIL import Image
+
+ from transformers import LayoutLMv3FeatureExtractor
+
+
+class LayoutLMv3FeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=18,
+ apply_ocr=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.apply_ocr = apply_ocr
+
+ def prepare_feat_extract_dict(self):
+ return {"do_resize": self.do_resize, "size": self.size, "apply_ocr": self.apply_ocr}
+
+
+@require_torch
+@require_pytesseract
+class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = LayoutLMv3FeatureExtractor if is_pytesseract_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = LayoutLMv3FeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "apply_ocr"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoding = feature_extractor(image_inputs[0], return_tensors="pt")
+ self.assertEqual(
+ encoding.pixel_values.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ self.assertIsInstance(encoding.words, list)
+ self.assertIsInstance(encoding.boxes, list)
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_LayoutLMv3_integration_test(self):
+ # with apply_OCR = True
+ feature_extractor = LayoutLMv3FeatureExtractor()
+
+ from datasets import load_dataset
+
+ ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")
+
+ image = Image.open(ds[0]["file"]).convert("RGB")
+
+ encoding = feature_extractor(image, return_tensors="pt")
+
+ self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
+ self.assertEqual(len(encoding.words), len(encoding.boxes))
+
+ # fmt: off
+ # the words and boxes were obtained with Tesseract 4.1.1
+ expected_words = [['11:14', 'to', '11:39', 'a.m', '11:39', 'to', '11:44', 'a.m.', '11:44', 'a.m.', 'to', '12:25', 'p.m.', '12:25', 'to', '12:58', 'p.m.', '12:58', 'to', '4:00', 'p.m.', '2:00', 'to', '5:00', 'p.m.', 'Coffee', 'Break', 'Coffee', 'will', 'be', 'served', 'for', 'men', 'and', 'women', 'in', 'the', 'lobby', 'adjacent', 'to', 'exhibit', 'area.', 'Please', 'move', 'into', 'exhibit', 'area.', '(Exhibits', 'Open)', 'TRRF', 'GENERAL', 'SESSION', '(PART', '|)', 'Presiding:', 'Lee', 'A.', 'Waller', 'TRRF', 'Vice', 'President', 'āIntroductory', 'Remarksā', 'Lee', 'A.', 'Waller,', 'TRRF', 'Vice', 'Presi-', 'dent', 'Individual', 'Interviews', 'with', 'TRRF', 'Public', 'Board', 'Members', 'and', 'Sci-', 'entific', 'Advisory', 'Council', 'Mem-', 'bers', 'Conducted', 'by', 'TRRF', 'Treasurer', 'Philip', 'G.', 'Kuehn', 'to', 'get', 'answers', 'which', 'the', 'public', 'refrigerated', 'warehousing', 'industry', 'is', 'looking', 'for.', 'Plus', 'questions', 'from', 'the', 'floor.', 'Dr.', 'Emil', 'M.', 'Mrak,', 'University', 'of', 'Cal-', 'ifornia,', 'Chairman,', 'TRRF', 'Board;', 'Sam', 'R.', 'Cecil,', 'University', 'of', 'Georgia', 'College', 'of', 'Agriculture;', 'Dr.', 'Stanley', 'Charm,', 'Tufts', 'University', 'School', 'of', 'Medicine;', 'Dr.', 'Robert', 'H.', 'Cotton,', 'ITT', 'Continental', 'Baking', 'Company;', 'Dr.', 'Owen', 'Fennema,', 'University', 'of', 'Wis-', 'consin;', 'Dr.', 'Robert', 'E.', 'Hardenburg,', 'USDA.', 'Questions', 'and', 'Answers', 'Exhibits', 'Open', 'Capt.', 'Jack', 'Stoney', 'Room', 'TRRF', 'Scientific', 'Advisory', 'Council', 'Meeting', 'Ballroom', 'Foyer']] # noqa: E231
+ expected_boxes = [[[141, 57, 214, 69], [228, 58, 252, 69], [141, 75, 216, 88], [230, 79, 280, 88], [142, 260, 218, 273], [230, 261, 255, 273], [143, 279, 218, 290], [231, 282, 290, 291], [143, 342, 218, 354], [231, 345, 289, 355], [202, 362, 227, 373], [143, 379, 220, 392], [231, 382, 291, 394], [144, 714, 220, 726], [231, 715, 256, 726], [144, 732, 220, 745], [232, 736, 291, 747], [144, 769, 218, 782], [231, 770, 256, 782], [141, 788, 202, 801], [215, 791, 274, 804], [143, 826, 204, 838], [215, 826, 240, 838], [142, 844, 202, 857], [215, 847, 274, 859], [334, 57, 427, 69], [440, 57, 522, 69], [369, 75, 461, 88], [469, 75, 516, 88], [528, 76, 562, 88], [570, 76, 667, 88], [675, 75, 711, 87], [721, 79, 778, 88], [789, 75, 840, 88], [369, 97, 470, 107], [484, 94, 507, 106], [518, 94, 562, 107], [576, 94, 655, 110], [668, 94, 792, 109], [804, 95, 829, 107], [369, 113, 465, 125], [477, 116, 547, 125], [562, 113, 658, 125], [671, 116, 748, 125], [761, 113, 811, 125], [369, 131, 465, 143], [477, 133, 548, 143], [563, 130, 698, 145], [710, 130, 802, 146], [336, 171, 412, 183], [423, 171, 572, 183], [582, 170, 716, 184], [728, 171, 817, 187], [829, 171, 844, 186], [338, 197, 482, 212], [507, 196, 557, 209], [569, 196, 595, 208], [610, 196, 702, 209], [505, 214, 583, 226], [595, 214, 656, 227], [670, 215, 807, 227], [335, 259, 543, 274], [556, 259, 708, 272], [372, 279, 422, 291], [435, 279, 460, 291], [474, 279, 574, 292], [587, 278, 664, 291], [676, 278, 738, 291], [751, 279, 834, 291], [372, 298, 434, 310], [335, 341, 483, 354], [497, 341, 655, 354], [667, 341, 728, 354], [740, 341, 825, 354], [335, 360, 430, 372], [442, 360, 534, 372], [545, 359, 687, 372], [697, 360, 754, 372], [765, 360, 823, 373], [334, 378, 428, 391], [440, 378, 577, 394], [590, 378, 705, 391], [720, 378, 801, 391], [334, 397, 400, 409], [370, 416, 529, 429], [544, 416, 576, 432], [587, 416, 665, 428], [677, 416, 814, 429], [372, 435, 452, 450], [465, 434, 495, 447], [511, 434, 600, 447], [611, 436, 637, 447], [649, 436, 694, 451], [705, 438, 824, 447], [369, 453, 452, 466], [464, 454, 509, 466], [522, 453, 611, 469], [625, 453, 792, 469], [370, 472, 556, 488], [570, 472, 684, 487], [697, 472, 718, 485], [732, 472, 835, 488], [369, 490, 411, 503], [425, 490, 484, 503], [496, 490, 635, 506], [645, 490, 707, 503], [718, 491, 761, 503], [771, 490, 840, 503], [336, 510, 374, 521], [388, 510, 447, 522], [460, 510, 489, 521], [503, 510, 580, 522], [592, 509, 736, 525], [745, 509, 770, 522], [781, 509, 840, 522], [338, 528, 434, 541], [448, 528, 596, 541], [609, 527, 687, 540], [700, 528, 792, 541], [336, 546, 397, 559], [407, 546, 431, 559], [443, 546, 525, 560], [537, 546, 680, 562], [688, 546, 714, 559], [722, 546, 837, 562], [336, 565, 449, 581], [461, 565, 485, 577], [497, 565, 665, 581], [681, 565, 718, 577], [732, 565, 837, 580], [337, 584, 438, 597], [452, 583, 521, 596], [535, 584, 677, 599], [690, 583, 787, 596], [801, 583, 825, 596], [338, 602, 478, 615], [492, 602, 530, 614], [543, 602, 638, 615], [650, 602, 676, 614], [688, 602, 788, 615], [802, 602, 843, 614], [337, 621, 502, 633], [516, 621, 615, 637], [629, 621, 774, 636], [789, 621, 827, 633], [337, 639, 418, 652], [432, 640, 571, 653], [587, 639, 731, 655], [743, 639, 769, 652], [780, 639, 841, 652], [338, 658, 440, 673], [455, 658, 491, 670], [508, 658, 602, 671], [616, 658, 638, 670], [654, 658, 835, 674], [337, 677, 429, 689], [337, 714, 482, 726], [495, 714, 548, 726], [561, 714, 683, 726], [338, 770, 461, 782], [474, 769, 554, 785], [489, 788, 562, 803], [576, 788, 643, 801], [656, 787, 751, 804], [764, 788, 844, 801], [334, 825, 421, 838], [430, 824, 574, 838], [584, 824, 723, 841], [335, 844, 450, 857], [464, 843, 583, 860], [628, 862, 755, 875], [769, 861, 848, 878]]] # noqa: E231
+ # fmt: on
+
+ self.assertListEqual(encoding.words, expected_words)
+ self.assertListEqual(encoding.boxes, expected_boxes)
+
+ # with apply_OCR = False
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+
+ encoding = feature_extractor(image, return_tensors="pt")
+
+ self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
diff --git a/tests/models/layoutlmv3/test_modeling_layoutlmv3.py b/tests/models/layoutlmv3/test_modeling_layoutlmv3.py
new file mode 100644
index 00000000000000..d5c8d42d22177a
--- /dev/null
+++ b/tests/models/layoutlmv3/test_modeling_layoutlmv3.py
@@ -0,0 +1,399 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch LayoutLMv3 model. """
+
+import copy
+import unittest
+
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ LayoutLMv3Config,
+ LayoutLMv3ForQuestionAnswering,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3Model,
+ )
+ from transformers.models.layoutlmv3.modeling_layoutlmv3 import LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import LayoutLMv3FeatureExtractor
+
+
+class LayoutLMv3ModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=2,
+ num_channels=3,
+ image_size=4,
+ patch_size=2,
+ text_seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=36,
+ num_hidden_layers=3,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ coordinate_size=6,
+ shape_size=6,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ range_bbox=1000,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.text_seq_length = text_seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.coordinate_size = coordinate_size
+ self.shape_size = shape_size
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+ self.range_bbox = range_bbox
+
+ # LayoutLMv3's sequence length equals the number of text tokens + number of patches + 1 (we add 1 for the CLS token)
+ self.text_seq_length = text_seq_length
+ self.image_seq_length = (image_size // patch_size) ** 2 + 1
+ self.seq_length = self.text_seq_length + self.image_seq_length
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.text_seq_length], self.vocab_size)
+
+ bbox = ids_tensor([self.batch_size, self.text_seq_length, 4], self.range_bbox)
+ # Ensure that bbox is legal
+ for i in range(bbox.shape[0]):
+ for j in range(bbox.shape[1]):
+ if bbox[i, j, 3] < bbox[i, j, 1]:
+ t = bbox[i, j, 3]
+ bbox[i, j, 3] = bbox[i, j, 1]
+ bbox[i, j, 1] = t
+ if bbox[i, j, 2] < bbox[i, j, 0]:
+ t = bbox[i, j, 2]
+ bbox[i, j, 2] = bbox[i, j, 0]
+ bbox[i, j, 0] = t
+
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.text_seq_length])
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.text_seq_length], self.type_vocab_size)
+
+ sequence_labels = None
+ token_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.text_seq_length], self.num_labels)
+
+ config = LayoutLMv3Config(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ initializer_range=self.initializer_range,
+ coordinate_size=self.coordinate_size,
+ shape_size=self.shape_size,
+ input_size=self.image_size,
+ patch_size=self.patch_size,
+ )
+
+ return config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+
+ def create_and_check_model(
+ self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+ ):
+ model = LayoutLMv3Model(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # text + image
+ result = model(input_ids, pixel_values=pixel_values)
+ result = model(
+ input_ids, bbox=bbox, pixel_values=pixel_values, attention_mask=input_mask, token_type_ids=token_type_ids
+ )
+ result = model(input_ids, bbox=bbox, pixel_values=pixel_values, token_type_ids=token_type_ids)
+ result = model(input_ids, bbox=bbox, pixel_values=pixel_values)
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ # text only
+ result = model(input_ids)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size)
+ )
+
+ # image only
+ result = model(pixel_values=pixel_values)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.image_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_for_sequence_classification(
+ self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+ ):
+ config.num_labels = self.num_labels
+ model = LayoutLMv3ForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=sequence_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def create_and_check_for_token_classification(
+ self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+ ):
+ config.num_labels = self.num_labels
+ model = LayoutLMv3ForTokenClassification(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=token_labels,
+ )
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.text_seq_length, self.num_labels))
+
+ def create_and_check_for_question_answering(
+ self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
+ ):
+ model = LayoutLMv3ForQuestionAnswering(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ bbox=bbox,
+ pixel_values=pixel_values,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ start_positions=sequence_labels,
+ end_positions=sequence_labels,
+ )
+ self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+ self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ bbox,
+ pixel_values,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ ) = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "bbox": bbox,
+ "pixel_values": pixel_values,
+ "token_type_ids": token_type_ids,
+ "attention_mask": input_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class LayoutLMv3ModelTest(ModelTesterMixin, unittest.TestCase):
+
+ test_pruning = False
+ test_torchscript = False
+ test_mismatched_shapes = False
+
+ all_model_classes = (
+ (
+ LayoutLMv3Model,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3ForQuestionAnswering,
+ )
+ if is_torch_available()
+ else ()
+ )
+
+ def setUp(self):
+ self.model_tester = LayoutLMv3ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LayoutLMv3Config, hidden_size=37)
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = copy.deepcopy(inputs_dict)
+ if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+ inputs_dict = {
+ k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
+ if isinstance(v, torch.Tensor) and v.ndim > 1
+ else v
+ for k, v in inputs_dict.items()
+ }
+ if return_labels:
+ if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+ inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
+ elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ inputs_dict["start_positions"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ inputs_dict["end_positions"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ elif model_class in [
+ *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
+ ]:
+ inputs_dict["labels"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ elif model_class in [
+ *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
+ ]:
+ inputs_dict["labels"] = torch.zeros(
+ (self.model_tester.batch_size, self.model_tester.text_seq_length),
+ dtype=torch.long,
+ device=torch_device,
+ )
+
+ return inputs_dict
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_various_embeddings(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ for type in ["absolute", "relative_key", "relative_key_query"]:
+ config_and_inputs[0].position_embedding_type = type
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_sequence_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+ def test_for_token_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+ def test_for_question_answering(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = LayoutLMv3Model.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+class LayoutLMv3ModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return LayoutLMv3FeatureExtractor(apply_ocr=False) if is_vision_available() else None
+
+ @slow
+ def test_inference_no_head(self):
+ model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base").to(torch_device)
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(torch_device)
+
+ input_ids = torch.tensor([[1, 2]])
+ bbox = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).unsqueeze(0)
+
+ # forward pass
+ outputs = model(
+ input_ids=input_ids.to(torch_device),
+ bbox=bbox.to(torch_device),
+ pixel_values=pixel_values.to(torch_device),
+ )
+
+ # verify the logits
+ expected_shape = torch.Size((1, 199, 768))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[-0.0529, 0.3618, 0.1632], [-0.1587, -0.1667, -0.0400], [-0.1557, -0.1671, -0.0505]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
diff --git a/tests/models/layoutlmv3/test_processor_layoutlmv3.py b/tests/models/layoutlmv3/test_processor_layoutlmv3.py
new file mode 100644
index 00000000000000..a01b0a00cd9047
--- /dev/null
+++ b/tests/models/layoutlmv3/test_processor_layoutlmv3.py
@@ -0,0 +1,446 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+from typing import List
+
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
+from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast
+from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES
+from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow
+from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available
+
+
+if is_pytesseract_available():
+ from PIL import Image
+
+ from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3Processor
+
+
+@require_pytesseract
+@require_tokenizers
+class LayoutLMv3ProcessorTest(unittest.TestCase):
+ tokenizer_class = LayoutLMv3Tokenizer
+ rust_tokenizer_class = LayoutLMv3TokenizerFast
+
+ def setUp(self):
+ # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
+ vocab = [
+ "l",
+ "o",
+ "w",
+ "e",
+ "r",
+ "s",
+ "t",
+ "i",
+ "d",
+ "n",
+ "\u0120",
+ "\u0120l",
+ "\u0120n",
+ "\u0120lo",
+ "\u0120low",
+ "er",
+ "\u0120lowest",
+ "\u0120newer",
+ "\u0120wider",
+ "",
+ ]
+ self.tmpdirname = tempfile.mkdtemp()
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+ merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
+ self.special_tokens_map = {"unk_token": ""}
+
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+ with open(self.merges_file, "w", encoding="utf-8") as fp:
+ fp.write("\n".join(merges))
+
+ feature_extractor_map = {
+ "do_resize": True,
+ "size": 224,
+ "apply_ocr": True,
+ }
+
+ self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(feature_extractor_map) + "\n")
+
+ def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
+ return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
+ return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
+ return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
+
+ def get_feature_extractor(self, **kwargs):
+ return LayoutLMv3FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def test_save_load_pretrained_default(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ processor.save_pretrained(self.tmpdirname)
+ processor = LayoutLMv3Processor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+ self.assertIsInstance(processor.tokenizer, (LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast))
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = LayoutLMv3Processor(feature_extractor=self.get_feature_extractor(), tokenizer=self.get_tokenizer())
+ processor.save_pretrained(self.tmpdirname)
+
+ # slow tokenizer
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)
+
+ processor = LayoutLMv3Processor.from_pretrained(
+ self.tmpdirname, use_fast=False, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, LayoutLMv3Tokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
+
+ # fast tokenizer
+ tokenizer_add_kwargs = self.get_rust_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)
+
+ processor = LayoutLMv3Processor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, LayoutLMv3TokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
+
+
+# different use cases tests
+@require_torch
+@require_pytesseract
+class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
+ @cached_property
+ def get_images(self):
+ # we verify our implementation on 2 document images from the DocVQA dataset
+ from datasets import load_dataset
+
+ ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")
+
+ image_1 = Image.open(ds[0]["file"]).convert("RGB")
+ image_2 = Image.open(ds[1]["file"]).convert("RGB")
+
+ return image_1, image_2
+
+ @cached_property
+ def get_tokenizers(self):
+ slow_tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
+ fast_tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
+ return [slow_tokenizer, fast_tokenizer]
+
+ @slow
+ def test_processor_case_1(self):
+ # case 1: document image classification (training, inference) + token classification (inference), apply_ocr = True
+
+ feature_extractor = LayoutLMv3FeatureExtractor()
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ input_feat_extract = feature_extractor(images[0], return_tensors="pt")
+ input_processor = processor(images[0], return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify image
+ self.assertAlmostEqual(
+ input_feat_extract["pixel_values"].sum(), input_processor["pixel_values"].sum(), delta=1e-2
+ )
+
+ # verify input_ids
+ # this was obtained with Tesseract 4.1.1
+ # fmt: off
+ expected_decoding = " 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President āIntroductory Remarksā Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231
+ # fmt: on
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # batched
+ input_feat_extract = feature_extractor(images, return_tensors="pt")
+ input_processor = processor(images, padding=True, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify images
+ self.assertAlmostEqual(
+ input_feat_extract["pixel_values"].sum(), input_processor["pixel_values"].sum(), delta=1e-2
+ )
+
+ # verify input_ids
+ # this was obtained with Tesseract 4.1.1
+ # fmt: off
+ expected_decoding = " 7 ITC Limited REPORT AND ACCOUNTS 2013 ITCās Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITCās value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITCās brands national assets, adding to Indiaās competitiveness. It is ITCās aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223" # noqa: E231
+ # fmt: on
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ @slow
+ def test_processor_case_2(self):
+ # case 2: document image classification (training, inference) + token classification (inference), apply_ocr=False
+
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ words = ["hello", "world"]
+ boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
+ input_processor = processor(images[0], words, boxes=boxes, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["input_ids", "bbox", "attention_mask", "pixel_values"]
+ actual_keys = list(input_processor.keys())
+ for key in expected_keys:
+ self.assertIn(key, actual_keys)
+
+ # verify input_ids
+ expected_decoding = " hello world"
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # batched
+ words = [["hello", "world"], ["my", "name", "is", "niels"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
+ input_processor = processor(images, words, boxes=boxes, padding=True, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " hello world"
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify bbox
+ expected_bbox = [
+ [0, 0, 0, 0],
+ [3, 2, 5, 1],
+ [6, 7, 4, 2],
+ [3, 9, 2, 4],
+ [1, 1, 2, 3],
+ [1, 1, 2, 3],
+ [0, 0, 0, 0],
+ ]
+ self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)
+
+ @slow
+ def test_processor_case_3(self):
+ # case 3: token classification (training), apply_ocr=False
+
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ words = ["weirdly", "world"]
+ boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
+ word_labels = [1, 2]
+ input_processor = processor(images[0], words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " weirdly world"
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify labels
+ expected_labels = [-100, 1, -100, 2, -100]
+ self.assertListEqual(input_processor.labels.squeeze().tolist(), expected_labels)
+
+ # batched
+ words = [["hello", "world"], ["my", "name", "is", "niels"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
+ word_labels = [[1, 2], [6, 3, 10, 2]]
+ input_processor = processor(
+ images, words, boxes=boxes, word_labels=word_labels, padding=True, return_tensors="pt"
+ )
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " my name is niels"
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify bbox
+ expected_bbox = [
+ [0, 0, 0, 0],
+ [3, 2, 5, 1],
+ [6, 7, 4, 2],
+ [3, 9, 2, 4],
+ [1, 1, 2, 3],
+ [1, 1, 2, 3],
+ [0, 0, 0, 0],
+ ]
+ self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)
+
+ # verify labels
+ expected_labels = [-100, 6, 3, 10, 2, -100, -100]
+ self.assertListEqual(input_processor.labels[1].tolist(), expected_labels)
+
+ @slow
+ def test_processor_case_4(self):
+ # case 4: visual question answering (inference), apply_ocr=True
+
+ feature_extractor = LayoutLMv3FeatureExtractor()
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ question = "What's his name?"
+ input_processor = processor(images[0], question, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ # this was obtained with Tesseract 4.1.1
+ # fmt: off
+ expected_decoding = " What's his name? 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President āIntroductory Remarksā Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231
+ # fmt: on
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # batched
+ questions = ["How old is he?", "what's the time"]
+ input_processor = processor(
+ images, questions, padding="max_length", max_length=20, truncation=True, return_tensors="pt"
+ )
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ # this was obtained with Tesseract 4.1.1
+ expected_decoding = " what's the time 7 ITC Limited REPORT AND ACCOUNTS 2013 ITC"
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify bbox
+ # fmt: off
+ expected_bbox = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 45, 67, 80], [72, 56, 109, 67], [72, 56, 109, 67], [116, 56, 189, 67], [198, 59, 253, 66], [257, 59, 285, 66], [289, 59, 365, 66], [289, 59, 365, 66], [289, 59, 365, 66], [372, 59, 407, 66], [74, 136, 161, 158], [74, 136, 161, 158], [0, 0, 0, 0]] # noqa: E231
+ # fmt: on
+ self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)
+
+ @slow
+ def test_processor_case_5(self):
+ # case 5: visual question answering (inference), apply_ocr=False
+
+ feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+ tokenizers = self.get_tokenizers
+ images = self.get_images
+
+ for tokenizer in tokenizers:
+ processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # not batched
+ question = "What's his name?"
+ words = ["hello", "world"]
+ boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
+ input_processor = processor(images[0], question, words, boxes, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " What's his name? hello world"
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # batched
+ questions = ["How old is he?", "what's the time"]
+ words = [["hello", "world"], ["my", "name", "is", "niels"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
+ input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt")
+
+ # verify keys
+ expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
+ actual_keys = sorted(list(input_processor.keys()))
+ self.assertListEqual(actual_keys, expected_keys)
+
+ # verify input_ids
+ expected_decoding = " How old is he? hello world"
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ expected_decoding = " what's the time my name is niels"
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
+ self.assertSequenceEqual(decoding, expected_decoding)
+
+ # verify bbox
+ expected_bbox = [[6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3], [1, 1, 2, 3], [0, 0, 0, 0]]
+ self.assertListEqual(input_processor.bbox[1].tolist()[-5:], expected_bbox)
diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
new file mode 100644
index 00000000000000..ae12129e787f9d
--- /dev/null
+++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -0,0 +1,2345 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import json
+import os
+import re
+import shutil
+import tempfile
+import unittest
+from typing import List
+
+from transformers import AddedToken, LayoutLMv3TokenizerFast, SpecialTokensMixin, is_tf_available, is_torch_available
+from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES, LayoutLMv3Tokenizer
+from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, require_tokenizers, require_torch, slow
+
+from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin, merge_model_tokenizer_mappings
+
+
+@require_tokenizers
+@require_pandas
+class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+ tokenizer_class = LayoutLMv3Tokenizer
+ rust_tokenizer_class = LayoutLMv3TokenizerFast
+ test_rust_tokenizer = True
+ # determined by the tokenization algortihm and the way it's decoded by the fast tokenizers
+ space_between_special_tokens = False
+ test_seq2seq = False
+ from_pretrained_kwargs = {"cls_token": ""}
+
+ def get_words_and_boxes(self):
+ words = ["lower", "newer"]
+ boxes = [[423, 237, 440, 251], [427, 272, 441, 287]]
+
+ return words, boxes
+
+ def get_words_and_boxes_batch(self):
+ words = [["lower", "newer"], ["new", "low"]]
+ boxes = [
+ [[423, 237, 440, 251], [427, 272, 441, 287]],
+ [[961, 885, 992, 912], [256, 38, 330, 58]],
+ ]
+
+ return words, boxes
+
+ def get_question_words_and_boxes(self):
+ question = "what's his name?"
+ words = ["lower", "newer"]
+ boxes = [[423, 237, 440, 251], [427, 272, 441, 287]]
+
+ return question, words, boxes
+
+ def get_question_words_and_boxes_batch(self):
+ questions = ["what's his name?", "how is he called?"]
+ words = [["lower", "newer"], ["newer", "lower"]]
+ boxes = [
+ [[423, 237, 440, 251], [427, 272, 441, 287]],
+ [[256, 38, 330, 58], [256, 38, 330, 58]],
+ ]
+
+ return questions, words, boxes
+
+ def setUp(self):
+ super().setUp()
+
+ # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
+ vocab = [
+ "l",
+ "o",
+ "w",
+ "e",
+ "r",
+ "s",
+ "t",
+ "i",
+ "d",
+ "n",
+ "\u0120",
+ "\u0120l",
+ "\u0120n",
+ "\u0120lo",
+ "\u0120low",
+ "er",
+ "\u0120lowest",
+ "\u0120newer",
+ "\u0120wider",
+ "",
+ ]
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+ merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
+ self.special_tokens_map = {"unk_token": ""}
+
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+ with open(self.merges_file, "w", encoding="utf-8") as fp:
+ fp.write("\n".join(merges))
+
+ def get_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return LayoutLMv3TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_input_output_texts(self, tokenizer):
+ input_text = "lower newer"
+ output_text = "lower newer"
+ return input_text, output_text
+
+ def test_full_tokenizer(self):
+ tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
+ text = "lower newer"
+ bpe_tokens = ["Ä low", "er", "Ä ", "n", "e", "w", "er"]
+ tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
+ self.assertListEqual(tokens, bpe_tokens)
+
+ input_tokens = tokens + [tokenizer.unk_token]
+ input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
+ self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+ @slow
+ def test_sequence_builders(self):
+ tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutlmv3-base")
+
+ question, words, boxes = self.get_question_words_and_boxes()
+
+ text = tokenizer.encode(
+ question.split(),
+ boxes=[tokenizer.pad_token_box for _ in range(len(question.split()))],
+ add_special_tokens=False,
+ )
+ text_2 = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+
+ encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
+
+ assert encoded_pair == [0] + text + [2] + [2] + text_2 + [2]
+
+ def test_add_special_tokens(self):
+ tokenizers: List[LayoutLMv3Tokenizer] = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ special_token = "[SPECIAL_TOKEN]"
+ special_token_box = [1000, 1000, 1000, 1000]
+
+ tokenizer.add_special_tokens({"cls_token": special_token})
+ encoded_special_token = tokenizer.encode(
+ [special_token], boxes=[special_token_box], add_special_tokens=False
+ )
+ self.assertEqual(len(encoded_special_token), 1)
+
+ decoded = tokenizer.decode(encoded_special_token, skip_special_tokens=True)
+ self.assertTrue(special_token not in decoded)
+
+ def test_add_tokens_tokenizer(self):
+ tokenizers: List[LayoutLMv3Tokenizer] = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ vocab_size = tokenizer.vocab_size
+ all_size = len(tokenizer)
+
+ self.assertNotEqual(vocab_size, 0)
+
+ # We usually have added tokens from the start in tests because our vocab fixtures are
+ # smaller than the original vocabs - let's not assert this
+ # self.assertEqual(vocab_size, all_size)
+
+ new_toks = ["aaaaa", "bbbbbb", "cccccccccdddddddd"]
+ added_toks = tokenizer.add_tokens(new_toks)
+ vocab_size_2 = tokenizer.vocab_size
+ all_size_2 = len(tokenizer)
+
+ self.assertNotEqual(vocab_size_2, 0)
+ self.assertEqual(vocab_size, vocab_size_2)
+ self.assertEqual(added_toks, len(new_toks))
+ self.assertEqual(all_size_2, all_size + len(new_toks))
+
+ words = "aaaaa bbbbbb low cccccccccdddddddd l".split()
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+
+ tokens = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+
+ self.assertGreaterEqual(len(tokens), 4)
+ self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+ self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+
+ new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
+ added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
+ vocab_size_3 = tokenizer.vocab_size
+ all_size_3 = len(tokenizer)
+
+ self.assertNotEqual(vocab_size_3, 0)
+ self.assertEqual(vocab_size, vocab_size_3)
+ self.assertEqual(added_toks_2, len(new_toks_2))
+ self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
+
+ words = ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l".split()
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+
+ tokens = tokenizer.encode(
+ words,
+ boxes=boxes,
+ add_special_tokens=False,
+ )
+
+ self.assertGreaterEqual(len(tokens), 6)
+ self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+ self.assertGreater(tokens[0], tokens[1])
+ self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+ self.assertGreater(tokens[-2], tokens[-3])
+ self.assertEqual(tokens[0], tokenizer.eos_token_id)
+ self.assertEqual(tokens[-2], tokenizer.pad_token_id)
+
+ @require_tokenizers
+ def test_encode_decode_with_spaces(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+
+ new_toks = [AddedToken("[ABC]", normalized=False), AddedToken("[DEF]", normalized=False)]
+ tokenizer.add_tokens(new_toks)
+ input = "[ABC][DEF][ABC][DEF]"
+ if self.space_between_special_tokens:
+ output = "[ABC] [DEF] [ABC] [DEF]"
+ else:
+ output = input
+ encoded = tokenizer.encode(input.split(), boxes=boxes, add_special_tokens=False)
+ decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens)
+ self.assertIn(decoded, [output, output.lower()])
+
+ @unittest.skip("Not implemented")
+ def test_right_and_left_truncation(self):
+ pass
+
+ def test_encode_plus_with_padding(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+
+ # check correct behaviour if no pad_token_id exists and add it eventually
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ padding_size = 10
+ padding_idx = tokenizer.pad_token_id
+
+ encoded_sequence = tokenizer.encode_plus(words, boxes=boxes, return_special_tokens_mask=True)
+ input_ids = encoded_sequence["input_ids"]
+ special_tokens_mask = encoded_sequence["special_tokens_mask"]
+ sequence_length = len(input_ids)
+
+ # Test 'longest' and 'no_padding' don't do anything
+ tokenizer.padding_side = "right"
+
+ not_padded_sequence = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ padding=False,
+ return_special_tokens_mask=True,
+ )
+ not_padded_input_ids = not_padded_sequence["input_ids"]
+
+ not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"]
+ not_padded_sequence_length = len(not_padded_input_ids)
+
+ self.assertTrue(sequence_length == not_padded_sequence_length)
+ self.assertTrue(input_ids == not_padded_input_ids)
+ self.assertTrue(special_tokens_mask == not_padded_special_tokens_mask)
+
+ not_padded_sequence = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ padding=False,
+ return_special_tokens_mask=True,
+ )
+ not_padded_input_ids = not_padded_sequence["input_ids"]
+
+ not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"]
+ not_padded_sequence_length = len(not_padded_input_ids)
+
+ self.assertTrue(sequence_length == not_padded_sequence_length)
+ self.assertTrue(input_ids == not_padded_input_ids)
+ self.assertTrue(special_tokens_mask == not_padded_special_tokens_mask)
+
+ # Test right padding
+ tokenizer.padding_side = "right"
+
+ right_padded_sequence = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ max_length=sequence_length + padding_size,
+ padding="max_length",
+ return_special_tokens_mask=True,
+ )
+ right_padded_input_ids = right_padded_sequence["input_ids"]
+
+ right_padded_special_tokens_mask = right_padded_sequence["special_tokens_mask"]
+ right_padded_sequence_length = len(right_padded_input_ids)
+
+ self.assertTrue(sequence_length + padding_size == right_padded_sequence_length)
+ self.assertTrue(input_ids + [padding_idx] * padding_size == right_padded_input_ids)
+ self.assertTrue(special_tokens_mask + [1] * padding_size == right_padded_special_tokens_mask)
+
+ # Test left padding
+ tokenizer.padding_side = "left"
+ left_padded_sequence = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ max_length=sequence_length + padding_size,
+ padding="max_length",
+ return_special_tokens_mask=True,
+ )
+ left_padded_input_ids = left_padded_sequence["input_ids"]
+ left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"]
+ left_padded_sequence_length = len(left_padded_input_ids)
+
+ self.assertTrue(sequence_length + padding_size == left_padded_sequence_length)
+ self.assertTrue([padding_idx] * padding_size + input_ids == left_padded_input_ids)
+ self.assertTrue([1] * padding_size + special_tokens_mask == left_padded_special_tokens_mask)
+
+ if "token_type_ids" in tokenizer.model_input_names:
+ token_type_ids = encoded_sequence["token_type_ids"]
+ left_padded_token_type_ids = left_padded_sequence["token_type_ids"]
+ right_padded_token_type_ids = right_padded_sequence["token_type_ids"]
+
+ assert token_type_ids + [0] * padding_size == right_padded_token_type_ids
+ assert [0] * padding_size + token_type_ids == left_padded_token_type_ids
+
+ if "attention_mask" in tokenizer.model_input_names:
+ attention_mask = encoded_sequence["attention_mask"]
+ right_padded_attention_mask = right_padded_sequence["attention_mask"]
+ left_padded_attention_mask = left_padded_sequence["attention_mask"]
+
+ self.assertTrue(attention_mask + [0] * padding_size == right_padded_attention_mask)
+ self.assertTrue([0] * padding_size + attention_mask == left_padded_attention_mask)
+
+ def test_internal_consistency(self):
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+
+ tokens = []
+ for word in words:
+ tokens.extend(tokenizer.tokenize(word))
+ ids = tokenizer.convert_tokens_to_ids(tokens)
+ ids_2 = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ self.assertListEqual(ids, ids_2)
+
+ tokens_2 = tokenizer.convert_ids_to_tokens(ids)
+ self.assertNotEqual(len(tokens_2), 0)
+ text_2 = tokenizer.decode(ids)
+ self.assertIsInstance(text_2, str)
+
+ output_text = " lower newer"
+ self.assertEqual(text_2, output_text)
+
+ def test_mask_output(self):
+ tokenizers = self.get_tokenizers(fast=False, do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+
+ if (
+ tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer"
+ and "token_type_ids" in tokenizer.model_input_names
+ ):
+ information = tokenizer.encode_plus(words, boxes=boxes, add_special_tokens=True)
+ sequences, mask = information["input_ids"], information["token_type_ids"]
+ self.assertEqual(len(sequences), len(mask))
+
+ def test_number_of_added_tokens(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ # test 1: single sequence
+ words, boxes = self.get_words_and_boxes()
+
+ sequences = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ attached_sequences = tokenizer.encode(words, boxes=boxes, add_special_tokens=True)
+
+ # Method is implemented (e.g. not GPT-2)
+ if len(attached_sequences) != 2:
+ self.assertEqual(
+ tokenizer.num_special_tokens_to_add(pair=False), len(attached_sequences) - len(sequences)
+ )
+
+ # test 2: two sequences
+ question, words, boxes = self.get_question_words_and_boxes()
+
+ sequences = tokenizer.encode(question, words, boxes=boxes, add_special_tokens=False)
+ attached_sequences = tokenizer.encode(question, words, boxes=boxes, add_special_tokens=True)
+
+ # Method is implemented (e.g. not GPT-2)
+ if len(attached_sequences) != 2:
+ self.assertEqual(
+ tokenizer.num_special_tokens_to_add(pair=True), len(attached_sequences) - len(sequences)
+ )
+
+ def test_padding_to_max_length(self):
+ """We keep this test for backward compatibility but it should be removed when `pad_to_max_length` will be deprecated"""
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ padding_size = 10
+
+ # check correct behaviour if no pad_token_id exists and add it eventually
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ padding_idx = tokenizer.pad_token_id
+
+ # Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+ tokenizer.padding_side = "right"
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+ # FIXME: the next line should be padding(max_length) to avoid warning
+ padded_sequence = tokenizer.encode(
+ words, boxes=boxes, max_length=sequence_length + padding_size, pad_to_max_length=True
+ )
+ padded_sequence_length = len(padded_sequence)
+ assert sequence_length + padding_size == padded_sequence_length
+ assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
+
+ # Check that nothing is done when a maximum length is not specified
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+
+ tokenizer.padding_side = "right"
+ padded_sequence_right = tokenizer.encode(words, boxes=boxes, pad_to_max_length=True)
+ padded_sequence_right_length = len(padded_sequence_right)
+ assert sequence_length == padded_sequence_right_length
+ assert encoded_sequence == padded_sequence_right
+
+ def test_padding(self, max_length=50):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
+ pad_token_id = tokenizer_p.pad_token_id
+
+ # Encode - Simple input
+ words, boxes = self.get_words_and_boxes()
+ input_r = tokenizer_r.encode(words, boxes=boxes, max_length=max_length, pad_to_max_length=True)
+ input_p = tokenizer_p.encode(words, boxes=boxes, max_length=max_length, pad_to_max_length=True)
+ self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
+ input_r = tokenizer_r.encode(words, boxes=boxes, max_length=max_length, padding="max_length")
+ input_p = tokenizer_p.encode(words, boxes=boxes, max_length=max_length, padding="max_length")
+ self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ input_r = tokenizer_r.encode(words, boxes=boxes, padding="longest")
+ input_p = tokenizer_p.encode(words, boxes=boxes, padding=True)
+ self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)
+
+ # Encode - Pair input
+ question, words, boxes = self.get_question_words_and_boxes()
+ input_r = tokenizer_r.encode(
+ question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True
+ )
+ input_p = tokenizer_p.encode(
+ question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True
+ )
+ self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
+ input_r = tokenizer_r.encode(question, words, boxes=boxes, max_length=max_length, padding="max_length")
+ input_p = tokenizer_p.encode(question, words, boxes=boxes, max_length=max_length, padding="max_length")
+ self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
+ input_r = tokenizer_r.encode(question, words, boxes=boxes, padding=True)
+ input_p = tokenizer_p.encode(question, words, boxes=boxes, padding="longest")
+ self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)
+
+ # Encode_plus - Simple input
+ words, boxes = self.get_words_and_boxes()
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=max_length, pad_to_max_length=True)
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=max_length, pad_to_max_length=True)
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=max_length, padding="max_length")
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=max_length, padding="max_length")
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes, padding="longest")
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes, padding=True)
+ self.assert_padded_input_match(
+ input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+ )
+
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ # Encode_plus - Pair input
+ question, words, boxes = self.get_question_words_and_boxes()
+ input_r = tokenizer_r.encode_plus(
+ question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True
+ )
+ input_p = tokenizer_p.encode_plus(
+ question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True
+ )
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus(
+ question, words, boxes=boxes, max_length=max_length, padding="max_length"
+ )
+ input_p = tokenizer_p.encode_plus(
+ question, words, boxes=boxes, max_length=max_length, padding="max_length"
+ )
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+ input_r = tokenizer_r.encode_plus(question, words, boxes=boxes, padding="longest")
+ input_p = tokenizer_p.encode_plus(question, words, boxes=boxes, padding=True)
+ self.assert_padded_input_match(
+ input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+ )
+ self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
+
+ # Batch_encode_plus - Simple input
+ words, boxes = self.get_words_and_boxes_batch()
+
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ pad_to_max_length=True,
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ pad_to_max_length=True,
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ padding="max_length",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ padding="max_length",
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ padding="longest",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ words,
+ boxes=boxes,
+ max_length=max_length,
+ padding=True,
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
+
+ input_r = tokenizer_r.batch_encode_plus(words, boxes=boxes, padding="longest")
+ input_p = tokenizer_p.batch_encode_plus(words, boxes=boxes, padding=True)
+ self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
+
+ # Batch_encode_plus - Pair input
+ questions, words, boxes = self.get_question_words_and_boxes_batch()
+
+ input_r = tokenizer_r.batch_encode_plus(
+ list(zip(questions, words)),
+ is_pair=True,
+ boxes=boxes,
+ max_length=max_length,
+ truncation=True,
+ padding="max_length",
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ list(zip(questions, words)),
+ is_pair=True,
+ boxes=boxes,
+ max_length=max_length,
+ truncation=True,
+ padding="max_length",
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ input_r = tokenizer_r.batch_encode_plus(
+ list(zip(questions, words)),
+ is_pair=True,
+ boxes=boxes,
+ padding=True,
+ )
+ input_p = tokenizer_p.batch_encode_plus(
+ list(zip(questions, words)),
+ is_pair=True,
+ boxes=boxes,
+ padding="longest",
+ )
+ self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
+
+ # Using pad on single examples after tokenization
+ words, boxes = self.get_words_and_boxes()
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes)
+ input_r = tokenizer_r.pad(input_r)
+
+ input_p = tokenizer_r.encode_plus(words, boxes=boxes)
+ input_p = tokenizer_r.pad(input_p)
+
+ self.assert_padded_input_match(
+ input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
+ )
+
+ # Using pad on single examples after tokenization
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes)
+ input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
+
+ input_p = tokenizer_r.encode_plus(words, boxes=boxes)
+ input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
+
+ self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
+
+ # Using pad after tokenization
+ words, boxes = self.get_words_and_boxes_batch()
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ )
+ input_r = tokenizer_r.pad(input_r)
+
+ input_p = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ )
+ input_p = tokenizer_r.pad(input_p)
+
+ self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)
+
+ # Using pad after tokenization
+ words, boxes = self.get_words_and_boxes_batch()
+ input_r = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ )
+ input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
+
+ input_p = tokenizer_r.batch_encode_plus(
+ words,
+ boxes=boxes,
+ )
+ input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
+
+ self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
+
+ def test_call(self):
+ # Tests that all call wrap to encode_plus and batch_encode_plus
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ # Test not batched
+ words, boxes = self.get_words_and_boxes()
+ encoded_sequences_1 = tokenizer.encode_plus(words, boxes=boxes)
+ encoded_sequences_2 = tokenizer(words, boxes=boxes)
+ self.assertEqual(encoded_sequences_1, encoded_sequences_2)
+
+ # Test not batched pairs
+ question, words, boxes = self.get_question_words_and_boxes()
+ encoded_sequences_1 = tokenizer.encode_plus(words, boxes=boxes)
+ encoded_sequences_2 = tokenizer(words, boxes=boxes)
+ self.assertEqual(encoded_sequences_1, encoded_sequences_2)
+
+ # Test batched
+ words, boxes = self.get_words_and_boxes_batch()
+ encoded_sequences_1 = tokenizer.batch_encode_plus(words, is_pair=False, boxes=boxes)
+ encoded_sequences_2 = tokenizer(words, boxes=boxes)
+ self.assertEqual(encoded_sequences_1, encoded_sequences_2)
+
+ def test_batch_encode_plus_batch_sequence_length(self):
+ # Tests that all encoded values have the correct size
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes_batch()
+
+ encoded_sequences = [
+ tokenizer.encode_plus(words_example, boxes=boxes_example)
+ for words_example, boxes_example in zip(words, boxes)
+ ]
+ encoded_sequences_batch = tokenizer.batch_encode_plus(words, is_pair=False, boxes=boxes, padding=False)
+ self.assertListEqual(
+ encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch)
+ )
+
+ maximum_length = len(
+ max([encoded_sequence["input_ids"] for encoded_sequence in encoded_sequences], key=len)
+ )
+
+ # check correct behaviour if no pad_token_id exists and add it eventually
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ encoded_sequences_padded = [
+ tokenizer.encode_plus(
+ words_example, boxes=boxes_example, max_length=maximum_length, padding="max_length"
+ )
+ for words_example, boxes_example in zip(words, boxes)
+ ]
+
+ encoded_sequences_batch_padded = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, padding=True
+ )
+ self.assertListEqual(
+ encoded_sequences_padded,
+ self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch_padded),
+ )
+
+ # check 'longest' is unsensitive to a max length
+ encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, padding=True
+ )
+ encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, max_length=maximum_length + 10, padding="longest"
+ )
+ for key in encoded_sequences_batch_padded_1.keys():
+ self.assertListEqual(
+ encoded_sequences_batch_padded_1[key],
+ encoded_sequences_batch_padded_2[key],
+ )
+
+ # check 'no_padding' is unsensitive to a max length
+ encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, padding=False
+ )
+ encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, max_length=maximum_length + 10, padding=False
+ )
+ for key in encoded_sequences_batch_padded_1.keys():
+ self.assertListEqual(
+ encoded_sequences_batch_padded_1[key],
+ encoded_sequences_batch_padded_2[key],
+ )
+
+ @unittest.skip("batch_encode_plus does not handle overflowing tokens.")
+ def test_batch_encode_plus_overflowing_tokens(self):
+ pass
+
+ def test_batch_encode_plus_padding(self):
+ # Test that padded sequences are equivalent between batch_encode_plus and encode_plus
+
+ # Right padding tests
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes_batch()
+
+ max_length = 100
+
+ # check correct behaviour if no pad_token_id exists and add it eventually
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ encoded_sequences = [
+ tokenizer.encode_plus(
+ words_example, boxes=boxes_example, max_length=max_length, padding="max_length"
+ )
+ for words_example, boxes_example in zip(words, boxes)
+ ]
+ encoded_sequences_batch = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, max_length=max_length, padding="max_length"
+ )
+ self.assertListEqual(
+ encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch)
+ )
+
+ # Left padding tests
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ tokenizer.padding_side = "left"
+ words, boxes = self.get_words_and_boxes_batch()
+
+ max_length = 100
+
+ # check correct behaviour if no pad_token_id exists and add it eventually
+ self._check_no_pad_token_padding(tokenizer, words)
+
+ encoded_sequences = [
+ tokenizer.encode_plus(
+ words_example, boxes=boxes_example, max_length=max_length, padding="max_length"
+ )
+ for words_example, boxes_example in zip(words, boxes)
+ ]
+ encoded_sequences_batch = tokenizer.batch_encode_plus(
+ words, is_pair=False, boxes=boxes, max_length=max_length, padding="max_length"
+ )
+ self.assertListEqual(
+ encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch)
+ )
+
+ def test_padding_to_multiple_of(self):
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ if tokenizer.pad_token is None:
+ self.skipTest("No padding token.")
+ else:
+ words, boxes = self.get_words_and_boxes()
+
+ # empty_tokens = tokenizer([""], [[]], padding=True, pad_to_multiple_of=8)
+ normal_tokens = tokenizer(words, boxes=boxes, padding=True, pad_to_multiple_of=8)
+ # for key, value in empty_tokens.items():
+ # self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
+ for key, value in normal_tokens.items():
+ self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
+
+ normal_tokens = tokenizer(words, boxes=boxes, pad_to_multiple_of=8)
+ for key, value in normal_tokens.items():
+ self.assertNotEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
+
+ # Should also work with truncation
+ normal_tokens = tokenizer(words, boxes=boxes, padding=True, truncation=True, pad_to_multiple_of=8)
+ for key, value in normal_tokens.items():
+ self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
+
+ # truncation to something which is not a multiple of pad_to_multiple_of raises an error
+ self.assertRaises(
+ ValueError,
+ tokenizer.__call__,
+ words,
+ boxes=boxes,
+ padding=True,
+ truncation=True,
+ max_length=12,
+ pad_to_multiple_of=8,
+ )
+
+ def test_tokenizer_slow_store_full_signature(self):
+ signature = inspect.signature(self.tokenizer_class.__init__)
+ tokenizer = self.get_tokenizer()
+
+ for parameter_name, parameter in signature.parameters.items():
+ if parameter.default != inspect.Parameter.empty:
+ self.assertIn(parameter_name, tokenizer.init_kwargs)
+
+ def test_build_inputs_with_special_tokens(self):
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ # Input tokens id
+ words, boxes = self.get_words_and_boxes()
+ input_simple = tokenizer_p.encode(words, boxes=boxes, add_special_tokens=False)
+ input_pair = tokenizer_p.encode(words, boxes=boxes, add_special_tokens=False)
+
+ # Generate output
+ output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
+ output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
+ self.assertEqual(output_p, output_r)
+
+ # Generate pair output
+ output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
+ output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
+ self.assertEqual(output_p, output_r)
+
+ def test_special_tokens_mask_input_pairs(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ encoded_sequence = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ encoded_sequence_dict = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ add_special_tokens=True,
+ return_special_tokens_mask=True,
+ # add_prefix_space=False,
+ )
+ encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
+ special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
+ self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
+
+ filtered_sequence = [
+ (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
+ ]
+ filtered_sequence = [x for x in filtered_sequence if x is not None]
+ self.assertEqual(encoded_sequence, filtered_sequence)
+
+ def test_special_tokens_mask(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ # Testing single inputs
+ encoded_sequence = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ encoded_sequence_dict = tokenizer.encode_plus(
+ words, boxes=boxes, add_special_tokens=True, return_special_tokens_mask=True
+ )
+ encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
+ special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
+ self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
+
+ filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]]
+ self.assertEqual(encoded_sequence, filtered_sequence)
+
+ def test_save_and_load_tokenizer(self):
+ # safety check on max_len default value so we are sure the test works
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ self.assertNotEqual(tokenizer.model_max_length, 42)
+
+ # Now let's start the test
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ # Isolate this from the other tests because we save additional tokens/etc
+ words, boxes = self.get_words_and_boxes()
+ tmpdirname = tempfile.mkdtemp()
+
+ before_tokens = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ before_vocab = tokenizer.get_vocab()
+ tokenizer.save_pretrained(tmpdirname)
+
+ after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
+ after_tokens = after_tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ after_vocab = after_tokenizer.get_vocab()
+ self.assertListEqual(before_tokens, after_tokens)
+ self.assertDictEqual(before_vocab, after_vocab)
+
+ shutil.rmtree(tmpdirname)
+
+ def test_right_and_left_padding(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ sequence = "Sequence"
+ padding_size = 10
+
+ # check correct behaviour if no pad_token_id exists and add it eventually
+ self._check_no_pad_token_padding(tokenizer, sequence)
+
+ padding_idx = tokenizer.pad_token_id
+
+ # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+ tokenizer.padding_side = "right"
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+ padded_sequence = tokenizer.encode(
+ words, boxes=boxes, max_length=sequence_length + padding_size, padding="max_length"
+ )
+ padded_sequence_length = len(padded_sequence)
+ assert sequence_length + padding_size == padded_sequence_length
+ assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
+
+ # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+ tokenizer.padding_side = "left"
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+ padded_sequence = tokenizer.encode(
+ words, boxes=boxes, max_length=sequence_length + padding_size, padding="max_length"
+ )
+ padded_sequence_length = len(padded_sequence)
+ assert sequence_length + padding_size == padded_sequence_length
+ assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
+
+ # RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_padding'
+ encoded_sequence = tokenizer.encode(words, boxes=boxes)
+ sequence_length = len(encoded_sequence)
+
+ tokenizer.padding_side = "right"
+ padded_sequence_right = tokenizer.encode(words, boxes=boxes, padding=True)
+ padded_sequence_right_length = len(padded_sequence_right)
+ assert sequence_length == padded_sequence_right_length
+ assert encoded_sequence == padded_sequence_right
+
+ tokenizer.padding_side = "left"
+ padded_sequence_left = tokenizer.encode(words, boxes=boxes, padding="longest")
+ padded_sequence_left_length = len(padded_sequence_left)
+ assert sequence_length == padded_sequence_left_length
+ assert encoded_sequence == padded_sequence_left
+
+ tokenizer.padding_side = "right"
+ padded_sequence_right = tokenizer.encode(words, boxes=boxes)
+ padded_sequence_right_length = len(padded_sequence_right)
+ assert sequence_length == padded_sequence_right_length
+ assert encoded_sequence == padded_sequence_right
+
+ tokenizer.padding_side = "left"
+ padded_sequence_left = tokenizer.encode(words, boxes=boxes, padding=False)
+ padded_sequence_left_length = len(padded_sequence_left)
+ assert sequence_length == padded_sequence_left_length
+ assert encoded_sequence == padded_sequence_left
+
+ def test_token_type_ids(self):
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ # test 1: single sequence
+ words, boxes = self.get_words_and_boxes()
+
+ output = tokenizer(words, boxes=boxes, return_token_type_ids=True)
+
+ # Assert that the token type IDs have the same length as the input IDs
+ self.assertEqual(len(output["token_type_ids"]), len(output["input_ids"]))
+
+ # Assert that the token type IDs have the same length as the attention mask
+ self.assertEqual(len(output["token_type_ids"]), len(output["attention_mask"]))
+
+ self.assertIn(0, output["token_type_ids"])
+ self.assertNotIn(1, output["token_type_ids"])
+
+ # test 2: two sequences (question + words)
+ question, words, boxes = self.get_question_words_and_boxes()
+
+ output = tokenizer(question, words, boxes, return_token_type_ids=True)
+
+ # Assert that the token type IDs have the same length as the input IDs
+ self.assertEqual(len(output["token_type_ids"]), len(output["input_ids"]))
+
+ # Assert that the token type IDs have the same length as the attention mask
+ self.assertEqual(len(output["token_type_ids"]), len(output["attention_mask"]))
+
+ self.assertIn(0, output["token_type_ids"])
+
+ def test_offsets_mapping(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ text = ["a", "wonderful", "test"]
+ boxes = [[1, 8, 12, 20] for _ in range(len(text))]
+
+ # No pair
+ tokens_with_offsets = tokenizer_r.encode_plus(
+ text,
+ boxes=boxes,
+ return_special_tokens_mask=True,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+ added_tokens = tokenizer_r.num_special_tokens_to_add(False)
+ offsets = tokens_with_offsets["offset_mapping"]
+
+ # Assert there is the same number of tokens and offsets
+ self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
+
+ # Assert there is online added_tokens special_tokens
+ self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
+
+ # Pairs
+ text = "what's his name"
+ pair = ["a", "wonderful", "test"]
+ boxes = [[1, 8, 12, 20] for _ in range(len(pair))]
+ tokens_with_offsets = tokenizer_r.encode_plus(
+ text,
+ pair,
+ boxes=boxes,
+ return_special_tokens_mask=True,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+ added_tokens = tokenizer_r.num_special_tokens_to_add(True)
+ offsets = tokens_with_offsets["offset_mapping"]
+
+ # Assert there is the same number of tokens and offsets
+ self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"]))
+
+ # Assert there is online added_tokens special_tokens
+ self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
+
+ @require_torch
+ @slow
+ def test_torch_encode_plus_sent_to_model(self):
+ import torch
+
+ from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+ MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+ return
+
+ config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+ config = config_class()
+
+ if config.is_encoder_decoder or config.pad_token_id is None:
+ return
+
+ model = model_class(config)
+
+ # Make sure the model contains at least the full vocabulary size in its embedding matrix
+ is_using_common_embeddings = hasattr(model.get_input_embeddings(), "weight")
+ assert (
+ (model.get_input_embeddings().weight.shape[0] >= len(tokenizer))
+ if is_using_common_embeddings
+ else True
+ )
+
+ # Build sequence
+ words, boxes = self.get_words_and_boxes()
+ encoded_sequence = tokenizer.encode_plus(words, boxes=boxes, return_tensors="pt")
+ batch_encoded_sequence = tokenizer.batch_encode_plus(
+ [words, words], boxes=[boxes, boxes], return_tensors="pt"
+ )
+
+ # We add dummy pixel_values keys (as LayoutLMv3 actually also requires a feature extractor
+ # to prepare the image input)
+ encoded_sequence["pixel_values"] = torch.randn(1, 3, 224, 224)
+ batch_encoded_sequence["pixel_values"] = torch.randn(2, 3, 224, 224)
+
+ # This should not fail
+ with torch.no_grad(): # saves some time
+ model(**encoded_sequence)
+ model(**batch_encoded_sequence)
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ words, boxes = self.get_words_and_boxes()
+
+ ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=True)
+ rust_ids = rust_tokenizer.encode(words, boxes=boxes, add_special_tokens=True)
+ self.assertListEqual(ids, rust_ids)
+
+ def test_tokenization_python_rust_equals(self):
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ words, boxes = self.get_words_and_boxes()
+
+ # Ensure basic input match
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes)
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes)
+
+ for key in filter(
+ lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys()
+ ):
+ self.assertSequenceEqual(input_p[key], input_r[key])
+
+ input_pairs_p = tokenizer_p.encode_plus(words, boxes=boxes)
+ input_pairs_r = tokenizer_r.encode_plus(words, boxes=boxes)
+
+ for key in filter(
+ lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys()
+ ):
+ self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])
+
+ words = ["hello" for _ in range(1000)]
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(1000)]
+
+ # Ensure truncation match
+ input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=512, truncation=True)
+ input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=512, truncation=True)
+
+ for key in filter(
+ lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys()
+ ):
+ self.assertSequenceEqual(input_p[key], input_r[key])
+
+ # Ensure truncation with stride match
+ input_p = tokenizer_p.encode_plus(
+ words, boxes=boxes, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
+ )
+ input_r = tokenizer_r.encode_plus(
+ words, boxes=boxes, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
+ )
+
+ for key in filter(
+ lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys()
+ ):
+ self.assertSequenceEqual(input_p[key], input_r[key][0])
+
+ def test_embeded_special_tokens(self):
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ words, boxes = self.get_words_and_boxes()
+ tokens_r = tokenizer_r.encode_plus(
+ words,
+ boxes=boxes,
+ add_special_tokens=True,
+ )
+ tokens_p = tokenizer_p.encode_plus(
+ words,
+ boxes=boxes,
+ add_special_tokens=True,
+ )
+
+ for key in tokens_p.keys():
+ self.assertEqual(tokens_r[key], tokens_p[key])
+
+ if "token_type_ids" in tokens_r:
+ self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
+
+ tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
+ tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
+ self.assertSequenceEqual(tokens_r, tokens_p)
+
+ def test_compare_add_special_tokens(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ simple_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=False)
+
+ words, boxes = self.get_words_and_boxes()
+ # tokenize()
+ no_special_tokens = tokenizer_r.tokenize(" ".join(words), add_special_tokens=False)
+ with_special_tokens = tokenizer_r.tokenize(" ".join(words), add_special_tokens=True)
+ self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add)
+
+ # encode()
+ no_special_tokens = tokenizer_r.encode(words, boxes=boxes, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.encode(words, boxes=boxes, add_special_tokens=True)
+ self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add)
+
+ # encode_plus()
+ no_special_tokens = tokenizer_r.encode_plus(words, boxes=boxes, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.encode_plus(words, boxes=boxes, add_special_tokens=True)
+ for key in no_special_tokens.keys():
+ self.assertEqual(
+ len(no_special_tokens[key]),
+ len(with_special_tokens[key]) - simple_num_special_tokens_to_add,
+ )
+
+ # # batch_encode_plus
+ words, boxes = self.get_words_and_boxes_batch()
+
+ no_special_tokens = tokenizer_r.batch_encode_plus(words, boxes=boxes, add_special_tokens=False)
+ with_special_tokens = tokenizer_r.batch_encode_plus(words, boxes=boxes, add_special_tokens=True)
+ for key in no_special_tokens.keys():
+ for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
+ self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
+
+ @slow
+ def test_layoutlmv3_truncation_integration_test(self):
+ words, boxes = self.get_words_and_boxes()
+
+ tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", model_max_length=512)
+
+ for i in range(12, 512):
+ new_encoded_inputs = tokenizer.encode(words, boxes=boxes, max_length=i, truncation=True)
+
+ # Ensure that the input IDs are less than the max length defined.
+ self.assertLessEqual(len(new_encoded_inputs), i)
+
+ tokenizer.model_max_length = 20
+ new_encoded_inputs = tokenizer.encode(words, boxes=boxes, truncation=True)
+ dropped_encoded_inputs = tokenizer.encode(words, boxes=boxes, truncation=True)
+
+ # Ensure that the input IDs are still truncated when no max_length is specified
+ self.assertListEqual(new_encoded_inputs, dropped_encoded_inputs)
+ self.assertLessEqual(len(new_encoded_inputs), 20)
+
+ @is_pt_tf_cross_test
+ def test_batch_encode_plus_tensors(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes_batch()
+
+ # A Tensor cannot be build by sequences which are not the same size
+ self.assertRaises(ValueError, tokenizer.batch_encode_plus, words, boxes=boxes, return_tensors="pt")
+ self.assertRaises(ValueError, tokenizer.batch_encode_plus, words, boxes=boxes, return_tensors="tf")
+
+ if tokenizer.pad_token_id is None:
+ self.assertRaises(
+ ValueError,
+ tokenizer.batch_encode_plus,
+ words,
+ boxes=boxes,
+ padding=True,
+ return_tensors="pt",
+ )
+ self.assertRaises(
+ ValueError,
+ tokenizer.batch_encode_plus,
+ words,
+ boxes=boxes,
+ padding="longest",
+ return_tensors="tf",
+ )
+ else:
+ pytorch_tensor = tokenizer.batch_encode_plus(words, boxes=boxes, padding=True, return_tensors="pt")
+ tensorflow_tensor = tokenizer.batch_encode_plus(
+ words, boxes=boxes, padding="longest", return_tensors="tf"
+ )
+ encoded_sequences = tokenizer.batch_encode_plus(words, boxes=boxes, padding=True)
+
+ for key in encoded_sequences.keys():
+ pytorch_value = pytorch_tensor[key].tolist()
+ tensorflow_value = tensorflow_tensor[key].numpy().tolist()
+ encoded_value = encoded_sequences[key]
+
+ self.assertEqual(pytorch_value, tensorflow_value, encoded_value)
+
+ def test_sequence_ids(self):
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ if not tokenizer.is_fast:
+ continue
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ seq_0 = "Test this method."
+ seq_1 = ["With", "these", "inputs."]
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(seq_1))]
+
+ # We want to have sequence 0 and sequence 1 are tagged
+ # respectively with 0 and 1 token_ids
+ # (regardless of whether the model use token type ids)
+ # We use this assumption in the QA pipeline among other place
+ output = tokenizer(seq_0.split(), boxes=boxes)
+ self.assertIn(0, output.sequence_ids())
+
+ output = tokenizer(seq_0, seq_1, boxes=boxes)
+ self.assertIn(0, output.sequence_ids())
+ self.assertIn(1, output.sequence_ids())
+
+ if tokenizer.num_special_tokens_to_add(pair=True):
+ self.assertIn(None, output.sequence_ids())
+
+ def test_special_tokens_initialization(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+ added_tokens = [AddedToken("", lstrip=True)]
+
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+ pretrained_name, additional_special_tokens=added_tokens, **kwargs
+ )
+ words = "Hey this is a token".split()
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+ r_output = tokenizer_r.encode(words, boxes=boxes)
+
+ special_token_id = tokenizer_r.encode(
+ [""], boxes=[1000, 1000, 1000, 1000], add_special_tokens=False
+ )[0]
+
+ self.assertTrue(special_token_id in r_output)
+
+ if self.test_slow_tokenizer:
+ tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+ pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
+ )
+ tokenizer_p = self.tokenizer_class.from_pretrained(
+ pretrained_name, additional_special_tokens=added_tokens, **kwargs
+ )
+
+ words = "Hey this is a token".split()
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+
+ p_output = tokenizer_p.encode(words, boxes=boxes)
+ cr_output = tokenizer_cr.encode(words, boxes=boxes)
+
+ self.assertEqual(p_output, r_output)
+ self.assertEqual(cr_output, r_output)
+ self.assertTrue(special_token_id in p_output)
+ self.assertTrue(special_token_id in cr_output)
+
+ def test_training_new_tokenizer(self):
+ # This feature only exists for fast tokenizers
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_rust_tokenizer()
+ new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100)
+
+ # Test we can use the new tokenizer with something not seen during training
+ text = [["this", "is", "the"], ["how", "are", "you"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8], [1, 3, 4, 8]], [[5, 6, 7, 8], [4, 5, 6, 7], [3, 9, 2, 7]]]
+ inputs = new_tokenizer(text, boxes=boxes)
+ self.assertEqual(len(inputs["input_ids"]), 2)
+ decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+ expected_result = " this is the"
+
+ if tokenizer.backend_tokenizer.normalizer is not None:
+ expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
+ self.assertEqual(expected_result, decoded_input)
+
+ # We check that the parameters of the tokenizer remained the same
+ # Check we have the same number of added_tokens for both pair and non-pair inputs.
+ self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False))
+ self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True))
+
+ # Check we have the correct max_length for both pair and non-pair inputs.
+ self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence)
+ self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair)
+
+ # Assert the set of special tokens match as we didn't ask to change them
+ self.assertSequenceEqual(
+ tokenizer.all_special_tokens_extended,
+ new_tokenizer.all_special_tokens_extended,
+ )
+
+ self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map)
+
+ def test_training_new_tokenizer_with_special_tokens_change(self):
+ # This feature only exists for fast tokenizers
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_rust_tokenizer()
+ # Test with a special tokens map
+ class_signature = inspect.signature(tokenizer.__class__)
+ if "cls_token" in class_signature.parameters:
+ new_tokenizer = tokenizer.train_new_from_iterator(
+ SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: ""}
+ )
+ cls_id = new_tokenizer.get_vocab()[""]
+ self.assertEqual(new_tokenizer.cls_token, "")
+ self.assertEqual(new_tokenizer.cls_token_id, cls_id)
+
+ # Create a new mapping from the special tokens defined in the original tokenizer
+ special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
+ special_tokens_list.remove("additional_special_tokens")
+ special_tokens_map = {}
+ for token in special_tokens_list:
+ # Get the private one to avoid unnecessary warnings.
+ if getattr(tokenizer, f"_{token}") is not None:
+ special_token = getattr(tokenizer, token)
+ special_tokens_map[special_token] = f"{special_token}a"
+
+ # Train new tokenizer
+ new_tokenizer = tokenizer.train_new_from_iterator(
+ SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map
+ )
+
+ # Check the changes
+ for token in special_tokens_list:
+ # Get the private one to avoid unnecessary warnings.
+ if getattr(tokenizer, f"_{token}") is None:
+ continue
+ special_token = getattr(tokenizer, token)
+ if special_token in special_tokens_map:
+ new_special_token = getattr(new_tokenizer, token)
+ self.assertEqual(special_tokens_map[special_token], new_special_token)
+
+ new_id = new_tokenizer.get_vocab()[new_special_token]
+ self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id)
+
+ # Check if the AddedToken / string format has been kept
+ for special_token in tokenizer.all_special_tokens_extended:
+ if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map:
+ # The special token must appear identically in the list of the new tokenizer.
+ self.assertTrue(
+ special_token in new_tokenizer.all_special_tokens_extended,
+ f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
+ )
+ elif isinstance(special_token, AddedToken):
+ # The special token must appear in the list of the new tokenizer as an object of type AddedToken with
+ # the same parameters as the old AddedToken except the content that the user has requested to change.
+ special_token_str = special_token.content
+ new_special_token_str = special_tokens_map[special_token_str]
+
+ find = False
+ for candidate in new_tokenizer.all_special_tokens_extended:
+ if (
+ isinstance(candidate, AddedToken)
+ and candidate.content == new_special_token_str
+ and candidate.lstrip == special_token.lstrip
+ and candidate.rstrip == special_token.rstrip
+ and candidate.normalized == special_token.normalized
+ and candidate.single_word == special_token.single_word
+ ):
+ find = True
+ break
+ self.assertTrue(
+ find,
+ f"'{new_special_token_str}' doesn't appear in the list "
+ f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
+ f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}",
+ )
+ elif special_token not in special_tokens_map:
+ # The special token must appear identically in the list of the new tokenizer.
+ self.assertTrue(
+ special_token in new_tokenizer.all_special_tokens_extended,
+ f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}",
+ )
+
+ else:
+ # The special token must appear in the list of the new tokenizer as an object of type string.
+ self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens_extended)
+
+ # Test we can use the new tokenizer with something not seen during training
+ words = [["this", "is"], ["hello", "š¤"]]
+ boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3, 4], [5, 6, 7, 8]]]
+ inputs = new_tokenizer(words, boxes=boxes)
+ self.assertEqual(len(inputs["input_ids"]), 2)
+ decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+ expected_result = " this is"
+
+ if tokenizer.backend_tokenizer.normalizer is not None:
+ expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
+ self.assertEqual(expected_result, decoded_input)
+
+ def test_prepare_for_model(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ # only test prepare_for_model for the slow tokenizer
+ if tokenizer.__class__.__name__ == "LayoutLMv3TokenizerFast":
+ continue
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ words, boxes = self.get_words_and_boxes()
+ prepared_input_dict = tokenizer.prepare_for_model(words, boxes=boxes, add_special_tokens=True)
+
+ input_dict = tokenizer.encode_plus(words, boxes=boxes, add_special_tokens=True)
+
+ self.assertEqual(input_dict, prepared_input_dict)
+
+ def test_padding_different_model_input_name(self):
+ if not self.test_slow_tokenizer:
+ # as we don't have a slow version, we can't compare the outputs between slow and fast versions
+ return
+
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
+ pad_token_id = tokenizer_p.pad_token_id
+
+ words, boxes = self.get_words_and_boxes_batch()
+
+ input_r = tokenizer_r.batch_encode_plus(words, boxes=boxes)
+ input_p = tokenizer_r.batch_encode_plus(words, boxes=boxes)
+
+ # rename encoded batch to "inputs"
+ input_r["inputs"] = input_r[tokenizer_r.model_input_names[0]]
+ del input_r[tokenizer_r.model_input_names[0]]
+
+ input_p["inputs"] = input_p[tokenizer_p.model_input_names[0]]
+ del input_p[tokenizer_p.model_input_names[0]]
+
+ # Renaming `input_ids` to `inputs`
+ tokenizer_r.model_input_names = ["inputs"] + tokenizer_r.model_input_names[1:]
+ tokenizer_p.model_input_names = ["inputs"] + tokenizer_p.model_input_names[1:]
+
+ input_r = tokenizer_r.pad(input_r, padding="longest")
+ input_p = tokenizer_r.pad(input_p, padding="longest")
+
+ max_length = len(input_p["inputs"][0])
+ self.assert_batch_padded_input_match(
+ input_r, input_p, max_length, pad_token_id, model_main_input_name="inputs"
+ )
+
+ def test_batch_encode_dynamic_overflowing(self):
+ """
+ When calling batch_encode with multiple sequences, it can return different number of
+ overflowing encoding for each sequence:
+ [
+ Sequence 1: [Encoding 1, Encoding 2],
+ Sequence 2: [Encoding 1],
+ Sequence 3: [Encoding 1, Encoding 2, ... Encoding N]
+ ]
+ This needs to be padded so that it can represented as a tensor
+ """
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"):
+
+ if is_torch_available():
+ returned_tensor = "pt"
+ elif is_tf_available():
+ returned_tensor = "tf"
+ else:
+ returned_tensor = "jax"
+
+ # Single example
+ words = ["HuggingFace", "is", "solving", "NLP", "one", "commit", "at", "a", "time"]
+ boxes = [[i, i, i, i] for i in range(len(words))]
+ tokens = tokenizer.encode_plus(
+ words,
+ boxes=boxes,
+ max_length=6,
+ padding=True,
+ truncation=True,
+ return_tensors=returned_tensor,
+ return_overflowing_tokens=True,
+ )
+
+ for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
+ if key != "bbox":
+ self.assertEqual(len(tokens[key].shape), 2)
+ else:
+ self.assertEqual(len(tokens[key].shape), 3)
+
+ # Batch of examples
+ # For these 2 examples, 3 training examples will be created
+ words_batched = [
+ ["HuggingFace", "is", "solving", "NLP", "one", "commit", "at", "a", "time"],
+ ["Very", "tiny", "input"],
+ ]
+ boxes_batched = [[[i, i, i, i] for i in range(len(words_item))] for words_item in words_batched]
+ tokens = tokenizer.batch_encode_plus(
+ words_batched,
+ boxes=boxes_batched,
+ max_length=6,
+ padding=True,
+ truncation="only_first",
+ return_tensors=returned_tensor,
+ return_overflowing_tokens=True,
+ )
+
+ for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
+ if key != "bbox":
+ self.assertEqual(len(tokens[key].shape), 2)
+ self.assertEqual(tokens[key].shape[-1], 6)
+ else:
+ self.assertEqual(len(tokens[key].shape), 3)
+ self.assertEqual(tokens[key].shape[-1], 4)
+
+ @unittest.skip("TO DO: overwrite this very extensive test.")
+ def test_alignement_methods(self):
+ pass
+
+ def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5):
+ toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))]
+ toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
+ toks = list(
+ filter(
+ lambda t: [t[0]]
+ == tokenizer.encode(t[1].split(" "), boxes=len(t[1]) * [[1, 1, 1, 1]], add_special_tokens=False),
+ toks,
+ )
+ )
+ if max_length is not None and len(toks) > max_length:
+ toks = toks[:max_length]
+ if min_length is not None and len(toks) < min_length and len(toks) > 0:
+ while len(toks) < min_length:
+ toks = toks + toks
+ # toks_str = [t[1] for t in toks]
+ toks_ids = [t[0] for t in toks]
+
+ # Ensure consistency
+ output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
+ if " " not in output_txt and len(toks_ids) > 1:
+ output_txt = (
+ tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
+ + " "
+ + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
+ )
+ if with_prefix_space:
+ output_txt = " " + output_txt
+ words = output_txt.split(" ")
+ boxes = [[i, i, i, i] for i in range(len(words))]
+ output_ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=False)
+
+ return words, boxes, output_ids
+
+ def test_added_token_with_space_before(self):
+
+ tokenizer_s = self.get_tokenizer()
+ tokenizer_f = self.get_rust_tokenizer()
+
+ tokens_to_add = ["AAA", "bbb"]
+
+ words_with_space = [f" {token}" for token in tokens_to_add + tokenizer_s.unique_no_split_tokens]
+ words_without_space = tokens_to_add + tokenizer_s.unique_no_split_tokens
+ boxes = [[i, i, i, i] for i in range(len(words_with_space))]
+
+ tokens_to_add_formated = [
+ AddedToken(token, rstrip=True, lstrip=True, single_word=False) for token in tokens_to_add
+ ]
+ tokenizer_s.add_tokens(tokens_to_add_formated)
+ tokenizer_f.add_tokens(tokens_to_add_formated)
+
+ ids_s = tokenizer_s(words_with_space, boxes=boxes).input_ids
+ ids_f = tokenizer_f(words_with_space, boxes=boxes).input_ids
+
+ tokens_s = tokenizer_s.convert_ids_to_tokens(ids_s)
+ tokens_f = tokenizer_f.convert_ids_to_tokens(ids_f)
+
+ ids_s = tokenizer_s(words_without_space, boxes=boxes).input_ids
+ ids_f = tokenizer_f(words_without_space, boxes=boxes).input_ids
+
+ tokens_s = tokenizer_s.convert_ids_to_tokens(ids_s)
+ tokens_f = tokenizer_f.convert_ids_to_tokens(ids_f)
+
+ self.assertEqual(tokens_s, tokens_f)
+
+ def test_maximum_encoding_length_pair_input(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ # Build a sequence from our model's vocabulary
+ stride = 2
+ seq_0, boxes_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
+ question_0 = " ".join(map(str, seq_0))
+ if len(ids) <= 2 + stride:
+ seq_0 = (seq_0 + " ") * (2 + stride)
+ ids = None
+
+ seq0_tokens = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)
+ seq0_input_ids = seq0_tokens["input_ids"]
+
+ self.assertGreater(len(seq0_input_ids), 2 + stride)
+ question_1 = "This is another sentence to be encoded."
+ seq_1 = ["what", "a", "weird", "test", "weirdly", "weird"]
+ boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)]
+ seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+ if abs(len(seq0_input_ids) - len(seq1_tokens["input_ids"])) <= 2:
+ seq1_tokens_input_ids = seq1_tokens["input_ids"] + seq1_tokens["input_ids"]
+ seq_1 = tokenizer.decode(seq1_tokens_input_ids, clean_up_tokenization_spaces=False)
+ seq_1 = seq_1.split(" ")
+ boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)]
+ seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+ seq1_input_ids = seq1_tokens["input_ids"]
+
+ self.assertGreater(len(seq1_input_ids), 2 + stride)
+
+ smallest = seq1_input_ids if len(seq0_input_ids) > len(seq1_input_ids) else seq0_input_ids
+
+ # We are not using the special tokens - a bit too hard to test all the tokenizers with this
+ # TODO try this again later
+ sequence = tokenizer(
+ question_0, seq_1, boxes=boxes_1, add_special_tokens=False
+ ) # , add_prefix_space=False)
+
+ # Test with max model input length
+ model_max_length = tokenizer.model_max_length
+ self.assertEqual(model_max_length, 100)
+ seq_2 = seq_0 * model_max_length
+ question_2 = " ".join(map(str, seq_2))
+ boxes_2 = boxes_0 * model_max_length
+ self.assertGreater(len(seq_2), model_max_length)
+
+ sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+ total_length1 = len(sequence1["input_ids"])
+ sequence2 = tokenizer(question_2, seq_1, boxes=boxes_1, add_special_tokens=False)
+ total_length2 = len(sequence2["input_ids"])
+ self.assertLess(total_length1, model_max_length, "Issue with the testing sequence, please update it.")
+ self.assertGreater(
+ total_length2, model_max_length, "Issue with the testing sequence, please update it."
+ )
+
+ # Simple
+ padding_strategies = (
+ [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
+ )
+ for padding_state in padding_strategies:
+ with self.subTest(f"{tokenizer.__class__.__name__} Padding: {padding_state}"):
+ for truncation_state in [True, "longest_first", "only_first"]:
+ with self.subTest(f"{tokenizer.__class__.__name__} Truncation: {truncation_state}"):
+ output = tokenizer(
+ question_2,
+ seq_1,
+ boxes=boxes_1,
+ padding=padding_state,
+ truncation=truncation_state,
+ )
+ self.assertEqual(len(output["input_ids"]), model_max_length)
+ self.assertEqual(len(output["bbox"]), model_max_length)
+
+ output = tokenizer(
+ [question_2],
+ [seq_1],
+ boxes=[boxes_1],
+ padding=padding_state,
+ truncation=truncation_state,
+ )
+ self.assertEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertEqual(len(output["bbox"][0]), model_max_length)
+
+ # Simple
+ output = tokenizer(
+ question_1, seq_2, boxes=boxes_2, padding=padding_state, truncation="only_second"
+ )
+ self.assertEqual(len(output["input_ids"]), model_max_length)
+ self.assertEqual(len(output["bbox"]), model_max_length)
+
+ output = tokenizer(
+ [question_1], [seq_2], boxes=[boxes_2], padding=padding_state, truncation="only_second"
+ )
+ self.assertEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertEqual(len(output["bbox"][0]), model_max_length)
+
+ # Simple with no truncation
+ # Reset warnings
+ tokenizer.deprecation_warnings = {}
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ output = tokenizer(
+ question_1, seq_2, boxes=boxes_2, padding=padding_state, truncation=False
+ )
+ self.assertNotEqual(len(output["input_ids"]), model_max_length)
+ self.assertNotEqual(len(output["bbox"]), model_max_length)
+ self.assertEqual(len(cm.records), 1)
+ self.assertTrue(
+ cm.records[0].message.startswith(
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
+ )
+ )
+
+ tokenizer.deprecation_warnings = {}
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ output = tokenizer(
+ [question_1], [seq_2], boxes=[boxes_2], padding=padding_state, truncation=False
+ )
+ self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertNotEqual(len(output["bbox"][0]), model_max_length)
+ self.assertEqual(len(cm.records), 1)
+ self.assertTrue(
+ cm.records[0].message.startswith(
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
+ )
+ )
+ # Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation
+ truncated_first_sequence = (
+ tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"][:-2]
+ + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"]
+ )
+ truncated_second_sequence = (
+ tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"]
+ + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][:-2]
+ )
+ truncated_longest_sequence = (
+ truncated_first_sequence
+ if len(seq0_input_ids) > len(seq1_input_ids)
+ else truncated_second_sequence
+ )
+
+ overflow_first_sequence = (
+ tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"][-(2 + stride) :]
+ + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"]
+ )
+ overflow_second_sequence = (
+ tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"]
+ + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][-(2 + stride) :]
+ )
+ overflow_longest_sequence = (
+ overflow_first_sequence if len(seq0_input_ids) > len(seq1_input_ids) else overflow_second_sequence
+ )
+
+ bbox_first = [[0, 0, 0, 0]] * (len(seq0_input_ids) - 2)
+ bbox_first_sequence = bbox_first + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"]
+ overflowing_token_bbox_first_sequence_slow = [[0, 0, 0, 0]] * (2 + stride)
+ overflowing_token_bbox_first_sequence_fast = [[0, 0, 0, 0]] * (2 + stride) + tokenizer(
+ seq_1, boxes=boxes_1, add_special_tokens=False
+ )["bbox"]
+
+ bbox_second = [[0, 0, 0, 0]] * len(seq0_input_ids)
+ bbox_second_sequence = (
+ bbox_second + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"][:-2]
+ )
+ overflowing_token_bbox_second_sequence_slow = tokenizer(
+ seq_1, boxes=boxes_1, add_special_tokens=False
+ )["bbox"][-(2 + stride) :]
+ overflowing_token_bbox_second_sequence_fast = [[0, 0, 0, 0]] * len(seq0_input_ids) + tokenizer(
+ seq_1, boxes=boxes_1, add_special_tokens=False
+ )["bbox"][-(2 + stride) :]
+
+ bbox_longest_sequence = (
+ bbox_first_sequence if len(seq0_tokens) > len(seq1_tokens) else bbox_second_sequence
+ )
+ overflowing_token_bbox_longest_sequence_fast = (
+ overflowing_token_bbox_first_sequence_fast
+ if len(seq0_tokens) > len(seq1_tokens)
+ else overflowing_token_bbox_second_sequence_fast
+ )
+
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ information = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation="longest_first",
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+ truncated_sequence = information["input_ids"][0]
+ overflowing_tokens = information["input_ids"][1]
+ bbox = information["bbox"][0]
+ overflowing_bbox = information["bbox"][1]
+ self.assertEqual(len(information["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_longest_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
+ self.assertEqual(overflowing_tokens, overflow_longest_sequence)
+ self.assertEqual(bbox, bbox_longest_sequence)
+
+ self.assertEqual(len(overflowing_bbox), 2 + stride + len(smallest))
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_longest_sequence_fast)
+ else:
+ # No overflowing tokens when using 'longest' in python tokenizers
+ with self.assertRaises(ValueError) as context:
+ information = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation="longest_first",
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+
+ self.assertTrue(
+ context.exception.args[0].startswith(
+ "Not possible to return overflowing tokens for pair of sequences with the "
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
+ "for instance `only_second` or `only_first`."
+ )
+ )
+
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ information = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation=True,
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+ truncated_sequence = information["input_ids"][0]
+ overflowing_tokens = information["input_ids"][1]
+ bbox = information["bbox"][0]
+ overflowing_bbox = information["bbox"][1]
+ self.assertEqual(len(information["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_longest_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
+ self.assertEqual(overflowing_tokens, overflow_longest_sequence)
+ self.assertEqual(bbox, bbox_longest_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_longest_sequence_fast)
+ else:
+ # No overflowing tokens when using 'longest' in python tokenizers
+ with self.assertRaises(ValueError) as context:
+ information = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation=True,
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+
+ self.assertTrue(
+ context.exception.args[0].startswith(
+ "Not possible to return overflowing tokens for pair of sequences with the "
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
+ "for instance `only_second` or `only_first`."
+ )
+ )
+
+ information_first_truncated = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation="only_first",
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ truncated_sequence = information_first_truncated["input_ids"][0]
+ overflowing_tokens = information_first_truncated["input_ids"][1]
+ bbox = information_first_truncated["bbox"][0]
+ overflowing_bbox = information_first_truncated["bbox"][0]
+ self.assertEqual(len(information_first_truncated["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_first_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq1_input_ids))
+ self.assertEqual(overflowing_tokens, overflow_first_sequence)
+ self.assertEqual(bbox, bbox_first_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_fast)
+ else:
+ truncated_sequence = information_first_truncated["input_ids"]
+ overflowing_tokens = information_first_truncated["overflowing_tokens"]
+ overflowing_bbox = information_first_truncated["overflowing_token_boxes"]
+ bbox = information_first_truncated["bbox"]
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_first_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride)
+ self.assertEqual(overflowing_tokens, seq0_input_ids[-(2 + stride) :])
+ self.assertEqual(bbox, bbox_first_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_slow)
+
+ information_second_truncated = tokenizer(
+ question_0,
+ seq_1,
+ boxes=boxes_1,
+ max_length=len(sequence["input_ids"]) - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation="only_second",
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ truncated_sequence = information_second_truncated["input_ids"][0]
+ overflowing_tokens = information_second_truncated["input_ids"][1]
+ bbox = information_second_truncated["bbox"][0]
+ overflowing_bbox = information_second_truncated["bbox"][1]
+
+ self.assertEqual(len(information_second_truncated["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_second_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq0_input_ids))
+ self.assertEqual(overflowing_tokens, overflow_second_sequence)
+ self.assertEqual(bbox, bbox_second_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_fast)
+ else:
+ truncated_sequence = information_second_truncated["input_ids"]
+ overflowing_tokens = information_second_truncated["overflowing_tokens"]
+ bbox = information_second_truncated["bbox"]
+ overflowing_bbox = information_second_truncated["overflowing_token_boxes"]
+
+ self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
+ self.assertEqual(truncated_sequence, truncated_second_sequence)
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride)
+ self.assertEqual(overflowing_tokens, seq1_input_ids[-(2 + stride) :])
+ self.assertEqual(bbox, bbox_second_sequence)
+ self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_slow)
+
+ def test_maximum_encoding_length_single_input(self):
+ tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ seq_0, boxes_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
+
+ sequence = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)
+ total_length = len(sequence["input_ids"])
+
+ self.assertGreater(total_length, 4, "Issue with the testing sequence, please update it it's too short")
+
+ # Test with max model input length
+ model_max_length = tokenizer.model_max_length
+ self.assertEqual(model_max_length, 100)
+ seq_1 = seq_0 * model_max_length
+ boxes_1 = boxes_0 * model_max_length
+ sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+ total_length1 = len(sequence1["input_ids"])
+ self.assertGreater(
+ total_length1, model_max_length, "Issue with the testing sequence, please update it it's too short"
+ )
+
+ # Simple
+ padding_strategies = (
+ [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
+ )
+ for padding_state in padding_strategies:
+ with self.subTest(f"Padding: {padding_state}"):
+ for truncation_state in [True, "longest_first", "only_first"]:
+ with self.subTest(f"Truncation: {truncation_state}"):
+ output = tokenizer(
+ seq_1,
+ boxes=boxes_1,
+ padding=padding_state,
+ truncation=truncation_state,
+ )
+
+ self.assertEqual(len(output["input_ids"]), model_max_length)
+ self.assertEqual(len(output["bbox"]), model_max_length)
+
+ output = tokenizer(
+ [seq_1],
+ boxes=[boxes_1],
+ padding=padding_state,
+ truncation=truncation_state,
+ )
+ self.assertEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertEqual(len(output["bbox"][0]), model_max_length)
+
+ # Simple with no truncation
+ # Reset warnings
+ tokenizer.deprecation_warnings = {}
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ output = tokenizer(seq_1, boxes=boxes_1, padding=padding_state, truncation=False)
+ self.assertNotEqual(len(output["input_ids"]), model_max_length)
+ self.assertNotEqual(len(output["bbox"]), model_max_length)
+ self.assertEqual(len(cm.records), 1)
+ self.assertTrue(
+ cm.records[0].message.startswith(
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
+ )
+ )
+
+ tokenizer.deprecation_warnings = {}
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ output = tokenizer([seq_1], boxes=[boxes_1], padding=padding_state, truncation=False)
+ self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
+ self.assertNotEqual(len(output["bbox"][0]), model_max_length)
+ self.assertEqual(len(cm.records), 1)
+ self.assertTrue(
+ cm.records[0].message.startswith(
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
+ )
+ )
+ # Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation
+ stride = 2
+ information = tokenizer(
+ seq_0,
+ boxes=boxes_0,
+ max_length=total_length - 2,
+ add_special_tokens=False,
+ stride=stride,
+ truncation=True,
+ return_overflowing_tokens=True,
+ # add_prefix_space=False,
+ )
+
+ # Overflowing tokens are handled quite differently in slow and fast tokenizers
+ if isinstance(tokenizer, LayoutLMv3TokenizerFast):
+ truncated_sequence = information["input_ids"][0]
+ overflowing_tokens = information["input_ids"][1]
+ # bbox = information["bbox"][0]
+ # overflowing_bbox = information["bbox"][1]
+ self.assertEqual(len(information["input_ids"]), 2)
+
+ self.assertEqual(len(truncated_sequence), total_length - 2)
+ self.assertEqual(truncated_sequence, sequence["input_ids"][:-2])
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride)
+ self.assertEqual(overflowing_tokens, sequence["input_ids"][-(2 + stride) :])
+
+ # self.assertEqual(bbox, sequence["bbox"][:-2])
+ # self.assertEqual(overflowing_bbox, sequence["bbox"][-(2 + stride) :])
+ else:
+ truncated_sequence = information["input_ids"]
+ overflowing_tokens = information["overflowing_tokens"]
+ # bbox = information["bbox"]
+ # overflowing_bbox = information["overflowing_token_boxes"]
+ self.assertEqual(len(truncated_sequence), total_length - 2)
+ self.assertEqual(truncated_sequence, sequence["input_ids"][:-2])
+
+ self.assertEqual(len(overflowing_tokens), 2 + stride)
+ self.assertEqual(overflowing_tokens, sequence["input_ids"][-(2 + stride) :])
+ # self.assertEqual(bbox, sequence["bbox"][:-2])
+ # self.assertEqual(overflowing_bbox, sequence["bbox"][-(2 + stride) :])
+
+ @unittest.skip("LayoutLMv3 tokenizer requires boxes besides sequences.")
+ def test_pretokenized_inputs(self):
+ pass
+
+ @unittest.skip("LayoutLMv3 tokenizer always expects pretokenized inputs.")
+ def test_compare_pretokenized_inputs(self):
+ pass
+
+ @unittest.skip("LayoutLMv3 fast tokenizer does not support prepare_for_model")
+ def test_compare_prepare_for_model(self):
+ pass
+
+ @slow
+ def test_only_label_first_subword(self):
+ words = ["hello", "niels"]
+ boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))]
+ word_labels = [0, 1]
+
+ # test slow tokenizer
+ tokenizer_p = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
+ encoding = tokenizer_p(words, boxes=boxes, word_labels=word_labels)
+ self.assertListEqual(encoding.labels, [-100, 0, 1, -100, -100])
+
+ tokenizer_p = LayoutLMv3Tokenizer.from_pretrained(
+ "microsoft/layoutlmv3-base",
+ only_label_first_subword=False,
+ add_visual_labels=False,
+ )
+ encoding = tokenizer_p(words, boxes=boxes, word_labels=word_labels)
+ self.assertListEqual(encoding.labels, [-100, 0, 1, 1, -100])
+
+ # test fast tokenizer
+ tokenizer_r = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
+ encoding = tokenizer_r(words, boxes=boxes, word_labels=word_labels)
+ self.assertListEqual(encoding.labels, [-100, 0, 1, -100, -100])
+
+ tokenizer_r = LayoutLMv3Tokenizer.from_pretrained(
+ "microsoft/layoutlmv3-base",
+ only_label_first_subword=False,
+ add_visual_labels=False,
+ )
+ encoding = tokenizer_r(words, boxes=boxes, word_labels=word_labels)
+ self.assertListEqual(encoding.labels, [-100, 0, 1, 1, -100])
+
+ @slow
+ def test_layoutlmv3_integration_test(self):
+
+ tokenizer_p = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
+ tokenizer_r = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
+
+ # There are 3 cases:
+ # CASE 1: document image classification (training + inference), document image token classification (inference),
+ # in which case only words and normalized bounding boxes are provided to the tokenizer
+ # CASE 2: document image token classification (training),
+ # in which case one also provides word labels to the tokenizer
+ # CASE 3: document image visual question answering (inference),
+ # in which case one also provides a question to the tokenizer
+
+ # We need to test all 3 cases both on batched and non-batched inputs.
+
+ # CASE 1: not batched
+ words, boxes = self.get_words_and_boxes()
+
+ # fmt: off
+ expected_results = {'input_ids': [0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(words, boxes=boxes, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(words, boxes=boxes, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # CASE 1: batched
+ words, boxes = self.get_words_and_boxes_batch()
+
+ # fmt: off
+ expected_results = {'input_ids': [[0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 92, 614, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [961, 885, 992, 912], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'attention_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(words, boxes=boxes, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(words, boxes=boxes, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # CASE 2: not batched
+ words, boxes = self.get_words_and_boxes()
+ word_labels = [1, 2]
+
+ # fmt: off
+ expected_results = {'input_ids': [0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'labels': [-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # # CASE 2: batched
+ words, boxes = self.get_words_and_boxes_batch()
+ word_labels = [[1, 2], [2, 46]]
+
+ # fmt: off
+ expected_results = {'input_ids': [[0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 92, 614, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [961, 885, 992, 912], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'labels': [[-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [-100, 2, 46, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]], 'attention_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # # CASE 3: not batched
+ question, words, boxes = self.get_question_words_and_boxes()
+
+ # fmt: off
+ expected_results = {'input_ids': [0, 99, 18, 39, 766, 116, 2, 2, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(question, words, boxes, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(question, words, boxes, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ # # CASE 3: batched
+ questions, words, boxes = self.get_question_words_and_boxes_batch()
+
+ # fmt: off
+ expected_results = {'input_ids': [[0, 99, 18, 39, 766, 116, 2, 2, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 141, 16, 37, 373, 116, 2, 2, 13964, 795, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [256, 38, 330, 58], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231
+ # fmt: on
+
+ encoding_p = tokenizer_p(questions, words, boxes, padding="max_length", max_length=20)
+ encoding_r = tokenizer_r(questions, words, boxes, padding="max_length", max_length=20)
+ self.assertDictEqual(dict(encoding_p), expected_results)
+ self.assertDictEqual(dict(encoding_r), expected_results)
+
+ @unittest.skip("Doesn't support another framework than PyTorch")
+ def test_np_encode_plus_sent_to_model(self):
+ pass
diff --git a/tests/maskformer/__init__.py b/tests/models/layoutxlm/__init__.py
similarity index 100%
rename from tests/maskformer/__init__.py
rename to tests/models/layoutxlm/__init__.py
diff --git a/tests/layoutxlm/test_processor_layoutxlm.py b/tests/models/layoutxlm/test_processor_layoutxlm.py
similarity index 94%
rename from tests/layoutxlm/test_processor_layoutxlm.py
rename to tests/models/layoutxlm/test_processor_layoutxlm.py
index 3964765e32060c..d0d7eec28a34a9 100644
--- a/tests/layoutxlm/test_processor_layoutxlm.py
+++ b/tests/models/layoutxlm/test_processor_layoutxlm.py
@@ -17,7 +17,6 @@
import shutil
import tempfile
import unittest
-from os.path import dirname
from typing import List
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
@@ -38,9 +37,6 @@
from transformers import LayoutLMv2FeatureExtractor, LayoutXLMProcessor
-SAMPLE_SP = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
-
-
@require_pytesseract
@require_sentencepiece
@require_tokenizers
@@ -60,11 +56,14 @@ def setUp(self):
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(feature_extractor_map) + "\n")
+ # taken from `test_tokenization_layoutxlm.LayoutXLMTokenizationTest.test_save_pretrained`
+ self.tokenizer_pretrained_name = "hf-internal-testing/tiny-random-layoutxlm"
+
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
- return self.tokenizer_class.from_pretrained(SAMPLE_SP, **kwargs)
+ return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
- return self.rust_tokenizer_class.from_pretrained(SAMPLE_SP, **kwargs)
+ return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
@@ -177,10 +176,11 @@ def test_processor_case_1(self):
)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = " 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President āIntroductory Remarksā Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -198,10 +198,11 @@ def test_processor_case_1(self):
)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = " 7 ITC Limited REPORT AND ACCOUNTS 2013 ITCās Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITCās value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITCās brands national assets, adding to Indiaās competitiveness. It is ITCās aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
@slow
@@ -228,7 +229,7 @@ def test_processor_case_2(self):
# verify input_ids
expected_decoding = " hello world"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -243,7 +244,7 @@ def test_processor_case_2(self):
# verify input_ids
expected_decoding = " hello world"
- decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -282,7 +283,7 @@ def test_processor_case_3(self):
# verify input_ids
expected_decoding = " weirdly world"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify labels
@@ -304,7 +305,7 @@ def test_processor_case_3(self):
# verify input_ids
expected_decoding = " my name is niels"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -344,10 +345,11 @@ def test_processor_case_4(self):
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
# fmt: off
expected_decoding = " What's his name? 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President āIntroductory Remarksā Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231
# fmt: on
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -362,8 +364,9 @@ def test_processor_case_4(self):
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
+ # this was obtained with Tesseract 4.1.1
expected_decoding = " what's the time 7 ITC Limited REPORT AND ACCOUNTS 2013"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
@@ -396,7 +399,7 @@ def test_processor_case_5(self):
# verify input_ids
expected_decoding = " What's his name? hello world"
- decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+ decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# batched
@@ -412,11 +415,11 @@ def test_processor_case_5(self):
# verify input_ids
expected_decoding = " How old is he? hello world"
- decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+ decoding = processor.decode(input_processor.input_ids[0].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
expected_decoding = " what's the time my name is niels"
- decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+ decoding = processor.decode(input_processor.input_ids[1].tolist())
self.assertSequenceEqual(decoding, expected_decoding)
# verify bbox
diff --git a/tests/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
similarity index 99%
rename from tests/layoutxlm/test_tokenization_layoutxlm.py
rename to tests/models/layoutxlm/test_tokenization_layoutxlm.py
index 09ca5061fcc576..68aba50ecaf403 100644
--- a/tests/layoutxlm/test_tokenization_layoutxlm.py
+++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -14,7 +14,6 @@
# limitations under the License.
import inspect
-import os
import shutil
import tempfile
import unittest
@@ -23,6 +22,7 @@
from transformers import AddedToken, LayoutXLMTokenizerFast, SpecialTokensMixin, is_tf_available, is_torch_available
from transformers.models.layoutxlm.tokenization_layoutxlm import LayoutXLMTokenizer
from transformers.testing_utils import (
+ get_tests_dir,
is_pt_tf_cross_test,
require_pandas,
require_scatter,
@@ -32,7 +32,7 @@
slow,
)
-from ..test_tokenization_common import (
+from ...test_tokenization_common import (
SMALL_TRAINING_CORPUS,
TokenizerTesterMixin,
filter_non_english,
@@ -40,7 +40,7 @@
)
-SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_sentencepiece
@@ -1543,11 +1543,9 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
break
self.assertTrue(
find,
- (
- f"'{new_special_token_str}' doesn't appear in the list "
- f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
- f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}"
- ),
+ f"'{new_special_token_str}' doesn't appear in the list "
+ f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
+ f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}",
)
elif special_token not in special_tokens_map:
# The special token must appear identically in the list of the new tokenizer.
diff --git a/tests/mbart/__init__.py b/tests/models/led/__init__.py
similarity index 100%
rename from tests/mbart/__init__.py
rename to tests/models/led/__init__.py
diff --git a/tests/led/test_modeling_led.py b/tests/models/led/test_modeling_led.py
similarity index 97%
rename from tests/led/test_modeling_led.py
rename to tests/models/led/test_modeling_led.py
index 758834d7fe9225..e7dc31838aa313 100644
--- a/tests/led/test_modeling_led.py
+++ b/tests/models/led/test_modeling_led.py
@@ -24,9 +24,9 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -163,6 +163,7 @@ def get_config(self):
def get_pipeline_config(self):
config = self.get_config()
config.max_position_embeddings = 100
+ config.vocab_size = 300
return config
def prepare_config_and_inputs_for_common(self):
@@ -528,9 +529,26 @@ def test_seq_to_seq_generation(self):
no_repeat_ngram_size=3,
)
- EXPECTED_LEP = " the physics of @xmath0-boson will again play the central role in the frontier of particle physics if the gigaz option of the international linear collider ( ilc ) can be realized in its first phase. \n the expected sensitivity to the branching ratio of the rare decays, especially its exotic or rare processes, should be investigated comprehensively to evaluate their potential in probing new physics. in this work \n, we extend the previous studies of these decays to some new models and investigate the decays altogether. we are motivated by some recent studies on the singlet extension of the mssm, such as the next - to - minimal supersymmetric standard model ( nmssm ) @xcite and the nearly - minimal - supersymmetry - standard - model(nmssm)@xcite, where a light cp - odd higgs boson with singlet - dominant component may naturally arise from the spontaneous breaking of some approximate global symmetry. # 1#2#3#4#5#6#7#8#9#10#11#12 "
+ EXPECTED_LEP = (
+ " the physics of @xmath0-boson will again play the central role in the frontier of particle physics if the"
+ " gigaz option of the international linear collider ( ilc ) can be realized in its first phase. \n the"
+ " expected sensitivity to the branching ratio of the rare decays, especially its exotic or rare processes,"
+ " should be investigated comprehensively to evaluate their potential in probing new physics. in this work"
+ " \n, we extend the previous studies of these decays to some new models and investigate the decays"
+ " altogether. we are motivated by some recent studies on the singlet extension of the mssm, such as the"
+ " next - to - minimal supersymmetric standard model ( nmssm ) @xcite and the nearly - minimal -"
+ " supersymmetry - standard - model(nmssm)@xcite, where a light cp - odd higgs boson with singlet -"
+ " dominant component may naturally arise from the spontaneous breaking of some approximate global"
+ " symmetry. # 1#2#3#4#5#6#7#8#9#10#11#12 "
+ )
- EXPECTED_MAGNET = " the recent experiment in the surface states of the topological insulator bi@xmath0se @xmath1, however, reported that a large positive magnetoresistance becomes very linear in perpendicular magnetic field even in an opposite situation where the carrier sheet density is high that all electrons occupy more than one landau levels. \n it is striking that this observation is in conflict with abrikosov s model and also with the classical parish - littlewood model. "
+ EXPECTED_MAGNET = (
+ " the recent experiment in the surface states of the topological insulator bi@xmath0se @xmath1, however,"
+ " reported that a large positive magnetoresistance becomes very linear in perpendicular magnetic field"
+ " even in an opposite situation where the carrier sheet density is high that all electrons occupy more"
+ " than one landau levels. \n it is striking that this observation is in conflict with abrikosov s model"
+ " and also with the classical parish - littlewood model. "
+ )
generated = tok.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
diff --git a/tests/led/test_modeling_tf_led.py b/tests/models/led/test_modeling_tf_led.py
similarity index 99%
rename from tests/led/test_modeling_tf_led.py
rename to tests/models/led/test_modeling_tf_led.py
index df115010f33e11..8075d071e6626b 100644
--- a/tests/led/test_modeling_tf_led.py
+++ b/tests/models/led/test_modeling_tf_led.py
@@ -19,8 +19,8 @@
from transformers import LEDConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
diff --git a/tests/mbart50/__init__.py b/tests/models/levit/__init__.py
similarity index 100%
rename from tests/mbart50/__init__.py
rename to tests/models/levit/__init__.py
diff --git a/tests/models/levit/test_feature_extraction_levit.py b/tests/models/levit/test_feature_extraction_levit.py
new file mode 100644
index 00000000000000..98a704b97a62d3
--- /dev/null
+++ b/tests/models/levit/test_feature_extraction_levit.py
@@ -0,0 +1,195 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import LevitFeatureExtractor
+
+
+class LevitFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=18,
+ do_center_crop=True,
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_center_crop = do_center_crop
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "do_center_crop": self.do_center_crop,
+ "size": self.size,
+ }
+
+
+@require_torch
+@require_vision
+class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = LevitFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = LevitFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size,
+ self.feature_extract_tester.size,
+ ),
+ )
diff --git a/tests/models/levit/test_modeling_levit.py b/tests/models/levit/test_modeling_levit.py
new file mode 100644
index 00000000000000..725b279fd02f40
--- /dev/null
+++ b/tests/models/levit/test_modeling_levit.py
@@ -0,0 +1,427 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch LeViT model. """
+
+
+import inspect
+import unittest
+import warnings
+from math import ceil, floor
+
+from transformers import LevitConfig
+from transformers.file_utils import cached_property, is_torch_available, is_vision_available
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ MODEL_MAPPING,
+ LevitForImageClassification,
+ LevitForImageClassificationWithTeacher,
+ LevitModel,
+ )
+ from transformers.models.levit.modeling_levit import LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import LevitFeatureExtractor
+
+
+class LevitConfigTester(ConfigTester):
+ def create_and_test_config_common_properties(self):
+ config = self.config_class(**self.inputs_dict)
+ self.parent.assertTrue(hasattr(config, "hidden_sizes"))
+ self.parent.assertTrue(hasattr(config, "num_attention_heads"))
+
+
+class LevitModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=64,
+ num_channels=3,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ patch_size=16,
+ hidden_sizes=[128, 256, 384],
+ num_attention_heads=[4, 6, 8],
+ depths=[2, 3, 4],
+ key_dim=[16, 16, 16],
+ drop_path_rate=0,
+ mlp_ratio=[2, 2, 2],
+ attention_ratio=[2, 2, 2],
+ initializer_range=0.02,
+ is_training=True,
+ use_labels=True,
+ num_labels=2, # Check
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.hidden_sizes = hidden_sizes
+ self.num_attention_heads = num_attention_heads
+ self.depths = depths
+ self.key_dim = key_dim
+ self.drop_path_rate = drop_path_rate
+ self.patch_size = patch_size
+ self.attention_ratio = attention_ratio
+ self.mlp_ratio = mlp_ratio
+ self.initializer_range = initializer_range
+ self.down_ops = [
+ ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
+ ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
+ ]
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.num_labels = num_labels
+ self.initializer_range = initializer_range
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.num_labels)
+
+ config = self.get_config()
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return LevitConfig(
+ image_size=self.image_size,
+ num_channels=self.num_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ patch_size=self.patch_size,
+ hidden_sizes=self.hidden_sizes,
+ num_attention_heads=self.num_attention_heads,
+ depths=self.depths,
+ key_dim=self.key_dim,
+ drop_path_rate=self.drop_path_rate,
+ mlp_ratio=self.mlp_ratio,
+ attention_ratio=self.attention_ratio,
+ initializer_range=self.initializer_range,
+ down_ops=self.down_ops,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = LevitModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ image_size = (self.image_size, self.image_size)
+ height, width = image_size[0], image_size[1]
+ for _ in range(4):
+ height = floor(((height + 2 * self.padding - self.kernel_size) / self.stride) + 1)
+ width = floor(((width + 2 * self.padding - self.kernel_size) / self.stride) + 1)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, ceil(height / 4) * ceil(width / 4), self.hidden_sizes[-1]),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.num_labels
+ model = LevitForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class LevitModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as Levit does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (
+ (LevitModel, LevitForImageClassification, LevitForImageClassificationWithTeacher)
+ if is_torch_available()
+ else ()
+ )
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ has_attentions = False
+
+ def setUp(self):
+ self.model_tester = LevitModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LevitConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ @unittest.skip(reason="Levit does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Levit does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ @unittest.skip(reason="Levit does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = len(self.model_tester.depths) + 1
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ height, width = image_size[0], image_size[1]
+ for _ in range(4):
+ height = floor(
+ (
+ (height + 2 * self.model_tester.padding - self.model_tester.kernel_size)
+ / self.model_tester.stride
+ )
+ + 1
+ )
+ width = floor(
+ (
+ (width + 2 * self.model_tester.padding - self.model_tester.kernel_size)
+ / self.model_tester.stride
+ )
+ + 1
+ )
+ # verify the first hidden states (first block)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [
+ height * width,
+ self.model_tester.hidden_sizes[0],
+ ],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+ if return_labels:
+ if model_class.__name__ == "LevitForImageClassificationWithTeacher":
+ del inputs_dict["labels"]
+
+ return inputs_dict
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ # special case for LevitForImageClassificationWithTeacher model
+ def test_training(self):
+ if not self.model_tester.is_training:
+ return
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ # LevitForImageClassificationWithTeacher supports inference-only
+ if (
+ model_class in get_values(MODEL_MAPPING)
+ or model_class.__name__ == "LevitForImageClassificationWithTeacher"
+ ):
+ continue
+ model = model_class(config)
+ model.to(torch_device)
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ def test_training_gradient_checkpointing(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ if not self.model_tester.is_training:
+ return
+
+ config.use_cache = False
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
+ continue
+ # LevitForImageClassificationWithTeacher supports inference-only
+ if model_class.__name__ == "LevitForImageClassificationWithTeacher":
+ continue
+ model = model_class(config)
+ model.gradient_checkpointing_enable()
+ model.to(torch_device)
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ def test_problem_types(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ problem_types = [
+ {"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
+ {"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
+ {"title": "regression", "num_labels": 1, "dtype": torch.float},
+ ]
+
+ for model_class in self.all_model_classes:
+ if (
+ model_class
+ not in [
+ *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
+ ]
+ or model_class.__name__ == "LevitForImageClassificationWithTeacher"
+ ):
+ continue
+
+ for problem_type in problem_types:
+ with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
+
+ config.problem_type = problem_type["title"]
+ config.num_labels = problem_type["num_labels"]
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.train()
+
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+
+ if problem_type["num_labels"] > 1:
+ inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
+
+ inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
+
+ # This tests that we do not trigger the warning form PyTorch "Using a target size that is different
+ # to the input size. This will likely lead to incorrect results due to broadcasting. Please ensure
+ # they have the same size." which is a symptom something in wrong for the regression problem.
+ # See https://github.com/huggingface/transformers/issues/11780
+ with warnings.catch_warnings(record=True) as warning_list:
+ loss = model(**inputs).loss
+ for w in warning_list:
+ if "Using a target size that is different to the input size" in str(w.message):
+ raise ValueError(
+ f"Something is going wrong in the regression problem: intercepted {w.message}"
+ )
+
+ loss.backward()
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = LevitModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class LevitModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return LevitFeatureExtractor.from_pretrained(LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = LevitForImageClassificationWithTeacher.from_pretrained(LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(
+ torch_device
+ )
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([1.0448, -0.3745, -1.8317]).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
diff --git a/tests/megatron_bert/__init__.py b/tests/models/longformer/__init__.py
similarity index 100%
rename from tests/megatron_bert/__init__.py
rename to tests/models/longformer/__init__.py
diff --git a/tests/longformer/test_modeling_longformer.py b/tests/models/longformer/test_modeling_longformer.py
similarity index 99%
rename from tests/longformer/test_modeling_longformer.py
rename to tests/models/longformer/test_modeling_longformer.py
index 6b3a8752ed4ee7..c1839d67d36c1c 100644
--- a/tests/longformer/test_modeling_longformer.py
+++ b/tests/models/longformer/test_modeling_longformer.py
@@ -19,8 +19,8 @@
from transformers import LongformerConfig, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
@@ -113,6 +113,11 @@ def get_config(self):
attention_window=self.attention_window,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def create_and_check_attention_mask_determinism(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
diff --git a/tests/longformer/test_modeling_tf_longformer.py b/tests/models/longformer/test_modeling_tf_longformer.py
similarity index 99%
rename from tests/longformer/test_modeling_tf_longformer.py
rename to tests/models/longformer/test_modeling_tf_longformer.py
index d483682677227f..12c19e566e95d4 100644
--- a/tests/longformer/test_modeling_tf_longformer.py
+++ b/tests/models/longformer/test_modeling_tf_longformer.py
@@ -19,8 +19,8 @@
from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/megatron_gpt2/__init__.py b/tests/models/longt5/__init__.py
similarity index 100%
rename from tests/megatron_gpt2/__init__.py
rename to tests/models/longt5/__init__.py
diff --git a/tests/models/longt5/test_modeling_flax_longt5.py b/tests/models/longt5/test_modeling_flax_longt5.py
new file mode 100644
index 00000000000000..9406e292d177a7
--- /dev/null
+++ b/tests/models/longt5/test_modeling_flax_longt5.py
@@ -0,0 +1,757 @@
+# coding=utf-8
+# Copyright 2022 Google LongT5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import tempfile
+import unittest
+
+import numpy as np
+
+import transformers
+from transformers import is_flax_available
+from transformers.models.auto import get_values
+from transformers.testing_utils import (
+ is_pt_flax_cross_test,
+ require_flax,
+ require_sentencepiece,
+ require_tokenizers,
+ slow,
+)
+
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+
+
+if is_flax_available():
+ import os
+
+ # The slow tests are often failing with OOM error on GPU
+ # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
+ # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
+ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+ import jax
+ import jax.numpy as jnp
+ from flax.core.frozen_dict import unfreeze
+ from flax.traverse_util import flatten_dict
+ from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING, AutoTokenizer, LongT5Config
+ from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
+ from transformers.models.longt5.modeling_flax_longt5 import (
+ FlaxLongT5ForConditionalGeneration,
+ FlaxLongT5Model,
+ shift_tokens_right,
+ )
+
+
+class FlaxLongT5ModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ encoder_seq_length=7,
+ decoder_seq_length=9,
+ local_radius=5,
+ encoder_attention_type="local",
+ global_block_size=3,
+ # For common tests
+ is_training=True,
+ use_attention_mask=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ d_ff=37,
+ relative_attention_num_buckets=8,
+ dropout_rate=0.1,
+ initializer_factor=0.002,
+ eos_token_id=1,
+ pad_token_id=0,
+ decoder_start_token_id=0,
+ scope=None,
+ decoder_layers=None,
+ ):
+
+ self.parent = parent
+ self.batch_size = batch_size
+ self.encoder_seq_length = encoder_seq_length
+ self.decoder_seq_length = decoder_seq_length
+ self.local_radius = local_radius
+ self.block_len = local_radius + 1
+ self.encoder_attention_type = encoder_attention_type
+ self.global_block_size = global_block_size
+ # For common tests
+ self.seq_length = self.decoder_seq_length
+ self.is_training = is_training
+ self.use_attention_mask = use_attention_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.d_ff = d_ff
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.dropout_rate = dropout_rate
+ self.initializer_factor = initializer_factor
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.decoder_start_token_id = decoder_start_token_id
+ self.scope = None
+ self.decoder_layers = decoder_layers
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+ decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ decoder_attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+ decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
+
+ config = LongT5Config(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_decoder_layers=self.decoder_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ local_radius=self.local_radius,
+ encoder_attention_type=self.encoder_attention_type,
+ global_block_size=self.global_block_size,
+ )
+
+ return (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ )
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ ):
+ model = FlaxLongT5Model(config=config)
+ result = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+ result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+ decoder_output = result.last_hidden_state
+ encoder_output = result.encoder_last_hidden_state
+
+ self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))
+ self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size))
+
+ def check_use_cache_forward_with_attn_mask(
+ self,
+ model_class_name,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ ):
+ max_decoder_length = 20
+ model = model_class_name(config)
+
+ encoder_outputs = model.encode(input_ids)
+
+ # prevent fully zero'd out attention mask
+ decoder_attention_mask = jnp.ones_like(decoder_attention_mask)
+
+ decoder_attention_mask_cache = jnp.concatenate(
+ [
+ decoder_attention_mask,
+ jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
+ ],
+ axis=-1,
+ )
+
+ past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
+
+ outputs_cache = model.decode(
+ decoder_input_ids[:, :-1],
+ encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask_cache,
+ past_key_values=past_key_values,
+ )
+ outputs_cache_next = model.decode(
+ decoder_input_ids[:, -1:],
+ encoder_outputs,
+ past_key_values=outputs_cache.past_key_values,
+ decoder_attention_mask=decoder_attention_mask_cache,
+ )
+
+ outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
+
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_flax
+class FlaxLongT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FlaxLongT5Model, FlaxLongT5ForConditionalGeneration) if is_flax_available() else ()
+ all_generative_model_classes = (FlaxLongT5ForConditionalGeneration,) if is_flax_available() else ()
+ is_encoder_decoder = True
+
+ def setUp(self):
+ self.model_tester = FlaxLongT5ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_v1_1(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ # check that gated gelu feed forward and different word embeddings work
+ config = config_and_inputs[0]
+ config.tie_word_embeddings = False
+ config.feed_forward_proj = "gated-gelu"
+ self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
+
+ def test_use_cache_forward_with_attn_mask(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs)
+
+ def test_encode(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ with self.subTest(model_class.__name__):
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config)
+
+ @jax.jit
+ def encode_jitted(input_ids, attention_mask=None, **kwargs):
+ return model.encode(input_ids=input_ids, attention_mask=attention_mask)
+
+ with self.subTest("JIT Enabled"):
+ jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
+
+ with self.subTest("JIT Disabled"):
+ with jax.disable_jit():
+ outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
+
+ self.assertEqual(len(outputs), len(jitted_outputs))
+ for jitted_output, output in zip(jitted_outputs, outputs):
+ self.assertEqual(jitted_output.shape, output.shape)
+
+ def test_decode(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ with self.subTest(model_class.__name__):
+ model = model_class(config)
+ encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])
+
+ prepared_inputs_dict = {
+ "decoder_input_ids": inputs_dict["decoder_input_ids"],
+ "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
+ "encoder_outputs": encoder_outputs,
+ }
+
+ @jax.jit
+ def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
+ return model.decode(
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ )
+
+ with self.subTest("JIT Enabled"):
+ jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
+
+ with self.subTest("JIT Disabled"):
+ with jax.disable_jit():
+ outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
+
+ self.assertEqual(len(outputs), len(jitted_outputs))
+ for jitted_output, output in zip(jitted_outputs, outputs):
+ self.assertEqual(jitted_output.shape, output.shape)
+
+ def test_shift_right(self):
+ decoder_start_token_id = 0
+ pad_token_id = 1
+ labels = np.arange(2, 102).reshape(5, 20)
+ labels[:2, 15:] = -100
+
+ decoder_input_ids = shift_tokens_right(labels, pad_token_id, decoder_start_token_id)
+ np_decoder_input_ids = np.array(decoder_input_ids)
+
+ padded_slice = np_decoder_input_ids[:2, (15 + 1) :]
+ self.assertTrue((padded_slice == 1).all())
+
+ not_padded_slice = np_decoder_input_ids[2:, 1:]
+ rolled_labels = np.roll(labels[2:], 1)[:, 1:]
+ self.assertTrue((not_padded_slice == rolled_labels).all())
+ self.assertTrue((np_decoder_input_ids[:, 0] == 0).all())
+
+ # overwrite since special base model prefix is used
+ def test_save_load_from_base(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = base_class(config)
+ base_params = flatten_dict(unfreeze(model.params))
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ head_model = model_class.from_pretrained(tmpdirname)
+
+ base_param_from_head = flatten_dict(unfreeze(head_model.params))
+
+ for key in base_param_from_head.keys():
+ max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ def test_save_load_to_base(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = model_class(config)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_length = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ block_len = getattr(self.model_tester, "block_len", None)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
+
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_from_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = base_class(config)
+ base_params = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ # save pt model
+ pt_model.save_pretrained(tmpdirname)
+ head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_param_from_head = flatten_dict(unfreeze(head_model.params))
+
+ for key in base_param_from_head.keys():
+ max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_to_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = model_class(config)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pt_model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+ # overwrite since special base model prefix is used
+ @is_pt_flax_cross_test
+ def test_save_load_bf16_to_base_pt(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
+
+ for model_class in self.all_model_classes:
+ if model_class == base_class:
+ continue
+
+ model = model_class(config)
+ model.params = model.to_bf16(model.params)
+ base_params_from_head = flatten_dict(unfreeze(model.params))
+
+ # convert Flax model to PyTorch model
+ pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
+ pt_model = pt_model_class(config).eval()
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
+
+ # check that all base model weights are loaded correctly
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pt_model.save_pretrained(tmpdirname)
+ base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
+
+ base_params = flatten_dict(unfreeze(base_model.params))
+
+ for key in base_params_from_head.keys():
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+
+class FlaxLongT5TGlobalModelTest(FlaxLongT5ModelTest):
+ def setUp(self):
+ self.model_tester = FlaxLongT5ModelTester(self, encoder_attention_type="transient-global")
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_length = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ block_len = getattr(self.model_tester, "block_len", None)
+ global_block_size = getattr(self.model_tester, "global_block_size", None)
+ global_seq_len = encoder_seq_length // global_block_size
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
+
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+
+
+@require_sentencepiece
+@require_tokenizers
+@require_flax
+class FlaxLongT5ModelIntegrationTests(unittest.TestCase):
+ model_path = "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
+
+ def expected_summary(self):
+ return [
+ "background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in"
+ " developing world . it provides an excellent resolution for visualization of the coronary arteries for"
+ " catheter - based or operating interventions . although the association of this technique with major"
+ " complications such as mortality is highly uncommon , it is frequently associated with various cardiac"
+ " and noncardiac complications . computed tomography coronary angiography is a promising technique for the"
+ " evaluation of cad noninvasively . it assesses disease within the coronary artery and provides"
+ " qualitative and quantitative information about nonobstructive atherosclerotic plaque"
+ ]
+
+ @slow
+ def test_summarization(self):
+ model = FlaxLongT5ForConditionalGeneration.from_pretrained(self.model_path)
+ tok = AutoTokenizer.from_pretrained(self.model_path)
+
+ ARTICLE = """coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . \n it provides an excellent resolution for visualization of the coronary arteries for catheter - based or operating interventions . \n
+ although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications . computed tomography ( ct ) coronary angiography is
+ a promising technique for the evaluation of cad noninvasively . \n it assesses disease within the coronary artery and provides qualitative and quantitative information about nonobstructive atherosclerotic plaque burden within the vessel
+ wall . \n thus , ct angiography - based disease evaluation may provide clinically more significant information than conventional angiography . the introduction of multi - slice computed tomography ( msct ) technology such as 64-slice , 12
+ 8-slice , 256-slice , and now 320-slice msct has produced a high diagnostic accuracy of ct coronary angiography . \n it has consistently showed to have a very high negative predictive value ( well above 90% ) in ruling out patients with s
+ ignificant cad defined as coronary luminal stenosis of > 50% . \n the american college of cardiology / american heart association recommends that coronary angiography should be performed before valve surgery in men aged > 40 years , women
+ aged > 35 years with coronary risk factors and in postmenopausal women . \n the prevalence of cad in patients undergoing valve replacement is 2040% in developed countries . in the previous studies , \n the incidence of angiographically p
+ roven cad in acquired valvular diseases has been shown to vary widely from 9% to 41% . in aortic stenosis , \n we aimed to report the diagnostic performance of 128-slice ct coronary angiography in 50 patients undergoing for major noncoron
+ ary cardiac surgery referred for diagnostic invasive coronary angiography to assess the extent and severity of coronary stenosis . \n during january 2013 to december 2014 , we enrolled fifty major noncoronary cardiac surgery patients sche
+ duled for invasive coronary angiography who fulfilled the following inclusion criteria of age 40 years , having low or intermediate probability of cad , left ventricular ejection fraction ( lvef ) > 35% , and patient giving informed conse
+ nt for undergoing msct and conventional coronary angiography . \n those having any contraindication for contrast injection , lvef < 35% , high pretest probability of cad , and hemodynamic instability were excluded from the study . \n pati
+ ents with heart rates of > 70 bpm received ( unless they had known overt heart failure or electrocardiogram ( ecg ) atrioventricular conduction abnormalities ) a single oral dose of 100 mg metoprolol 45 min before the scan . \n patients w
+ ith heart rates of > 80 bpm received an additional oral dose of metoprolol if not contraindicated . \n all patients were scanned with a 128-slice ct scanner ( siemens , somatom definition as ) equipped with a new feature in msct technolog
+ y , so - called z - axis flying - focus technology . \n the central 32 detector rows acquire 0.6-mm slices , and the flying - focus spot switches back and forth between 2 z positions between each reading . \n two slices per detector row a
+ re acquired , which results in a higher oversampling rate in the z - axis , thereby reducing artifacts related to the spiral acquisition and improving spatial resolution down to 0.4 mm . \n a bolus of 6580 ml contrast material ( omnipaque
+ ) was injected through an arm vein at a flow rate of 5 ml / s . \n a bolus tracking technique was used to synchronize the arrival of contrast in the coronary arteries with the initiation of the scan . to monitor the arrival of contrast m
+ aterial , \n axial scans were obtained at the level of the ascending aorta with a delay of 10 s after the start of the contrast injection . \n the scan was automatically started when a threshold of 150 hounsfield units was reached in a re
+ gion of interest positioned in the ascending aorta . \n images were reconstructed with ecg gating to obtain optimal , motion - free image quality . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a s
+ ingle observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiograp
+ hy . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean
+ lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiograp
+ hy . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of th
+ e number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were
+ used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the di
+ agnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and
+ positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease p
+ er vessel ) , and patient by patient ( no or any disease per patient ) . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary
+ lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal
+ irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary an
+ giography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated
+ software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified lesi
+ ons . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar reconstruction
+ s to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection of signif
+ icant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact 95% of
+ confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n
+ in this study , 29 ( 58% ) subjects were female , and 21 ( 42% ) were male showing an average age of 50.36 8.39 years . \n of fifty patients 24 ( 48% ) , 13 ( 26% ) , eight ( 16% ) , and five ( 10% ) underwent mitral valve replacement ,
+ double valve replacement ( dvr ) , aortic valve replacement , and other surgeries , respectively . \n high distribution of cad risk factors such as hypertension ( 24% ) , smoking ( 22% ) , and dyslipidemia ( 18% ) was observed in the stu
+ dy group . \n the mean creatinine level was 0.766 0.17 and average dye used in conventional angiography was 48.5 26.6 whereas for ct angiography it was 72.8 6.32 . \n average radiation dose in conventional coronary angiography and msct
+ coronary angiography was 5.2 msv and 9.2 msv , respectively . \n the majority of the patients had sinus rhythm ( 68% ) , whereas atrial fibrillation was found in 32% of the subjects . \n patients included in the study had low to intermed
+ iate probability of cad . in this study , three patients had complications after conventional angiography . \n complications were of local site hematoma , acute kidney injury managed conservatively , and acute heart failure . \n a patient
+ who developed hematoma was obese female patients with body mass index > 30 kg / m . \n the patient suffered from pseudoaneurysm , had hospitalized for 9 days , which leads to increased morbidity and cost of hospital stay . \n the diagnos
+ tic accuracy of ct coronary angiography was evaluated regarding true positive , true negative values and is presented in table 1 . the overall sensitivity and \n specificity of ct angiography technique was 100% ( 95% ci : 39.76%100% ) and
+ 91.30% ( 95% ci : 79.21%97.58% ) , respectively [ table 2 ] . \n the positive predictive value ( 50% ; 95% ci : 15.70%84.30% ) and negative predictive value ( 100% ; 95% ci : 91.59%100% ) of ct angiography were also fairly high in these
+ patients . \n recent reports from multiple studies demonstrated that recent - generation msct scanners showed promise for noninvasive detection of coronary stenosis however , until now no studies were found regarding the clinical efficacy
+ or prognostic value of 128-slice ct coronary angiography versus conventional invasive coronary angiography in the diagnosis of patients planned for major noncoronary surgeries such as dvr , bentall , atrial septal defect closure , etc .
+ in our study , we reported 8% cad prevalence in patients planned for major noncoronary cardiac surgery . \n we performed conventional and msct coronary angiography in all patients and the results showed that ct coronary angiography with i
+ nvasive coronary angiography as the reference standard had a considerably high sensitivity ( 100% ) and specificity ( 95.65% ) . \n the health economic model using invasive coronary angiography as the reference standard showed that at a p
+ retest probability of cad of 70% or lower , ct coronary angiography resulted in lower cost per patient with a true positive diagnosis . at a pretest probability of cad of 70% or higher , invasive coronary angiography was associated with a
+ lower cost per patient with a true positive diagnosis . in our study population , \n two patients developed local site complications in the form of hematoma and pseudoaneurysm after conventional angiography . \n hence , msct coronary ang
+ iography will be more favorable in female obese patients with intermediate likelihood of cad . \n hence , msct coronary angiography will be cost - effective in patients of valvular heart diseases . \n however , ct angiography suffers from
+ a drawback that average amount of dye used in msct coronary angiography were 72.8 6.32 ml which is higher than average amount of dye required for conventional angiography ( 48.6 26.6 ml ) . \n hence , the use of ct coronary angiography
+ could not be used in patients with known renal dysfunction , where reduction of contrast dye load is highly advocated . \n our results show that 128-slice ct coronary angiography is a reliable technique to detect coronary stenosis in pat
+ ients planned for noncoronary cardiac surgery . \n although there has been important technological progress in the development of ct coronary angiography , its clinical application remains limited . \n a study wth large numbers of patient
+ s is required for the recommendation of only ct coronary angiography for the coronary evaluation in major non - cardiac surgeries . \n mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , guja
+ rat , india ) . \n u.n . mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n """
+
+ dct = tok(
+ [ARTICLE],
+ max_length=1024,
+ padding="max_length",
+ truncation=True,
+ return_tensors="np",
+ )
+
+ hypotheses_batch = model.generate(
+ **dct,
+ num_beams=4,
+ length_penalty=2.0,
+ max_length=142,
+ min_length=56,
+ do_sample=False,
+ early_stopping=True,
+ ).sequences
+
+ decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ self.assertListEqual(
+ self.expected_summary(),
+ decoded,
+ )
diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py
new file mode 100644
index 00000000000000..65375e0fafdb53
--- /dev/null
+++ b/tests/models/longt5/test_modeling_longt5.py
@@ -0,0 +1,1313 @@
+# coding=utf-8
+# Copyright 2022 Google LongT5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+import tempfile
+import unittest
+
+from transformers import LongT5Config, is_torch_available
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
+from transformers.utils import cached_property
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ AutoTokenizer,
+ LongT5EncoderModel,
+ LongT5ForConditionalGeneration,
+ LongT5Model,
+ )
+ from transformers.models.longt5.modeling_longt5 import LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+class LongT5ModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ encoder_seq_length=7,
+ decoder_seq_length=9,
+ local_radius=5,
+ encoder_attention_type="local",
+ global_block_size=3,
+ # For common tests
+ is_training=True,
+ use_attention_mask=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ d_ff=37,
+ relative_attention_num_buckets=8,
+ dropout_rate=0.1,
+ initializer_factor=0.002,
+ eos_token_id=1,
+ pad_token_id=0,
+ decoder_start_token_id=0,
+ scope=None,
+ decoder_layers=None,
+ large_model_config_path="google/long-t5-local-large",
+ ):
+
+ self.parent = parent
+ self.batch_size = batch_size
+ self.encoder_seq_length = encoder_seq_length
+ self.decoder_seq_length = decoder_seq_length
+ self.local_radius = local_radius
+ self.block_len = local_radius + 1
+ self.encoder_attention_type = encoder_attention_type
+ self.global_block_size = global_block_size
+ # For common tests
+ self.seq_length = self.decoder_seq_length
+ self.is_training = is_training
+ self.use_attention_mask = use_attention_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.d_ff = d_ff
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.dropout_rate = dropout_rate
+ self.initializer_factor = initializer_factor
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.decoder_start_token_id = decoder_start_token_id
+ self.scope = None
+ self.decoder_layers = decoder_layers
+ self.large_model_config_path = large_model_config_path
+
+ def get_large_model_config(self):
+ return LongT5Config.from_pretrained(self.large_model_config_path)
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+ decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ decoder_attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+ decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
+
+ lm_labels = None
+ if self.use_labels:
+ lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+
+ config = self.get_config()
+
+ return (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ )
+
+ def get_pipeline_config(self):
+ return LongT5Config(
+ vocab_size=166, # longt5 forces 100 extra tokens
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_decoder_layers=self.decoder_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ local_radius=self.local_radius,
+ encoder_attention_type=self.encoder_attention_type,
+ global_block_size=self.global_block_size,
+ )
+
+ def get_config(self):
+ return LongT5Config(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_decoder_layers=self.decoder_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.decoder_start_token_id,
+ local_radius=self.local_radius,
+ encoder_attention_type=self.encoder_attention_type,
+ global_block_size=self.global_block_size,
+ )
+
+ def check_prepare_lm_labels_via_shift_left(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # make sure that lm_labels are correctly padded from the right
+ lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)
+
+ # add casaul pad token mask
+ triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
+ lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
+ decoder_input_ids = model._shift_right(lm_labels)
+
+ for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
+ # first item
+ self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
+ if i < decoder_input_ids_slice.shape[-1]:
+ if i < decoder_input_ids.shape[-1] - 1:
+ # items before diagonal
+ self.parent.assertListEqual(
+ decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
+ )
+ # pad items after diagonal
+ if i < decoder_input_ids.shape[-1] - 2:
+ self.parent.assertListEqual(
+ decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
+ )
+ else:
+ # all items after square
+ self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+ result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+ decoder_output = result.last_hidden_state
+ decoder_past = result.past_key_values
+ encoder_output = result.encoder_last_hidden_state
+
+ self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
+ self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
+ # There should be `num_layers` key value embeddings stored in decoder_past
+ self.parent.assertEqual(len(decoder_past), config.num_layers)
+ # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
+ self.parent.assertEqual(len(decoder_past[0]), 4)
+
+ def create_and_check_with_lm_head(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5ForConditionalGeneration(config=config).to(torch_device).eval()
+ outputs = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ labels=lm_labels,
+ )
+ self.parent.assertEqual(len(outputs), 4)
+ self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
+ self.parent.assertEqual(outputs["loss"].size(), ())
+
+ def create_and_check_decoder_model_past(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config).get_decoder().to(torch_device).eval()
+ # first forward pass
+ outputs = model(input_ids, use_cache=True)
+ outputs_use_cache_conf = model(input_ids)
+ outputs_no_past = model(input_ids, use_cache=False)
+
+ self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+ self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # append to next input_ids and
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+ output_from_no_past = model(next_input_ids)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_decoder_model_attention_mask_past(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config).get_decoder()
+ model.to(torch_device)
+ model.eval()
+
+ # create attention mask
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ half_seq_length = input_ids.shape[-1] // 2
+ attn_mask[:, half_seq_length:] = 0
+
+ # first forward pass
+ output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # change a random masked slice from input_ids
+ random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+ random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+ input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
+
+ # append to next input_ids and attn_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ attn_mask = torch.cat(
+ [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
+ dim=1,
+ )
+
+ # get two different outputs
+ output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_decoder_model_past_large_inputs(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5Model(config=config).get_decoder().to(torch_device).eval()
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical multiple next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append to next input_ids and
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_generate_with_past_key_values(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ model = LongT5ForConditionalGeneration(config=config).to(torch_device).eval()
+ torch.manual_seed(0)
+ output_without_past_cache = model.generate(
+ input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False
+ )
+ torch.manual_seed(0)
+ output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
+ self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
+
+ def create_and_check_encoder_decoder_shared_weights(
+ self,
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ):
+ for model_class in [LongT5Model, LongT5ForConditionalGeneration]:
+ torch.manual_seed(0)
+ model = model_class(config=config).to(torch_device).eval()
+ # load state dict copies weights but does not tie them
+ model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)
+
+ torch.manual_seed(0)
+ tied_config = copy.deepcopy(config)
+ tied_config.tie_encoder_decoder = True
+ tied_model = model_class(config=tied_config).to(torch_device).eval()
+
+ model_result = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+
+ tied_model_result = tied_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+
+ # check that models has less parameters
+ self.parent.assertLess(
+ sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
+ )
+ random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
+
+ # check that outputs are equal
+ self.parent.assertTrue(
+ torch.allclose(
+ model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
+ )
+ )
+
+ # check that outputs after saving and loading are equal
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ tied_model.save_pretrained(tmpdirname)
+ tied_model = model_class.from_pretrained(tmpdirname)
+ tied_model.to(torch_device)
+ tied_model.eval()
+
+ # check that models has less parameters
+ self.parent.assertLess(
+ sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
+ )
+ random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
+
+ tied_model_result = tied_model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ )
+
+ # check that outputs are equal
+ self.parent.assertTrue(
+ torch.allclose(
+ model_result[0][0, :, random_slice_idx],
+ tied_model_result[0][0, :, random_slice_idx],
+ atol=1e-4,
+ )
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ "use_cache": False,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (LongT5Model, LongT5ForConditionalGeneration) if is_torch_available() else ()
+ all_generative_model_classes = (LongT5ForConditionalGeneration,) if is_torch_available() else ()
+ fx_compatible = False
+ test_pruning = False
+ test_torchscript = True
+ test_resize_embeddings = True
+ test_model_parallel = False
+ is_encoder_decoder = True
+
+ def setUp(self):
+ self.model_tester = LongT5ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_shift_right(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_with_lm_head(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
+
+ def test_decoder_model_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
+
+ def test_decoder_model_past_with_attn_mask(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
+
+ def test_decoder_model_past_with_3d_attn_mask(self):
+ (
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ ) = self.model_tester.prepare_config_and_inputs()
+
+ attention_mask = ids_tensor(
+ [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length],
+ vocab_size=2,
+ )
+ decoder_attention_mask = ids_tensor(
+ [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length],
+ vocab_size=2,
+ )
+
+ self.model_tester.create_and_check_decoder_model_attention_mask_past(
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask,
+ decoder_attention_mask,
+ lm_labels,
+ )
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_generate_with_past_key_values(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs)
+
+ def test_encoder_decoder_shared_weights(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = LongT5Model.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ def test_export_to_onnx(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ model = LongT5Model(config_and_inputs[0]).to(torch_device)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ torch.onnx.export(
+ model,
+ (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
+ f"{tmpdirname}/longt5_test.onnx",
+ export_params=True,
+ opset_version=13,
+ input_names=["input_ids", "decoder_input_ids"],
+ )
+
+ def test_generate_with_head_masking(self):
+ attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ config = config_and_inputs[0]
+ max_length = config_and_inputs[1].shape[-1] + 3
+ model = LongT5ForConditionalGeneration(config).eval()
+ model.to(torch_device)
+
+ head_masking = {
+ "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
+ "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
+ "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
+ }
+
+ for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
+ head_masks = {name: mask}
+ # Explicitly pass decoder_head_mask as it is required from LONGT5 model when head_mask specified
+ if name == "head_mask":
+ head_masks["decoder_head_mask"] = torch.ones(
+ config.num_decoder_layers, config.num_heads, device=torch_device
+ )
+
+ out = model.generate(
+ config_and_inputs[1],
+ num_beams=1,
+ max_length=max_length,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ **head_masks,
+ )
+ # We check the state of decoder_attentions and cross_attentions just from the last step
+ attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
+ self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ chunk_length = getattr(self.model_tester, "chunk_length", None)
+ block_len = getattr(self.model_tester, "block_len", None)
+
+ if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
+ encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # loss is at first position
+ if "labels" in inputs_dict:
+ correct_outlen += 1 # loss is added to beginning
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+ if "past_key_values" in outputs:
+ correct_outlen += 1 # past_key_values have been returned
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
+
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+
+ def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
+ block_len = getattr(self.model_tester, "block_len", None)
+ encoder_expected_shape = (batch_size, 1, config.num_attention_heads, block_len, 3 * block_len)
+ self.assertIsInstance(attentions, tuple)
+ self.assertListEqual(
+ [layer_attentions.shape for layer_attentions in attentions],
+ [encoder_expected_shape] * len(attentions),
+ )
+
+
+@require_torch
+class LongT5TGlobalModelTest(LongT5ModelTest):
+ def setUp(self):
+ self.model_tester = LongT5ModelTester(
+ self, encoder_attention_type="transient-global", large_model_config_path="google/long-t5-tglobal-large"
+ )
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ chunk_length = getattr(self.model_tester, "chunk_length", None)
+ block_len = getattr(self.model_tester, "block_len", None)
+ global_block_size = getattr(self.model_tester, "global_block_size", None)
+ global_seq_len = encoder_seq_length // global_block_size
+
+ if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
+ encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # loss is at first position
+ if "labels" in inputs_dict:
+ correct_outlen += 1 # loss is added to beginning
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+ if "past_key_values" in outputs:
+ correct_outlen += 1 # past_key_values have been returned
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
+
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+
+ def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
+ block_len = getattr(self.model_tester, "block_len", None)
+ global_block_size = getattr(self.model_tester, "global_block_size", None)
+ global_seq_length = seq_length // global_block_size
+ encoder_expected_shape = (
+ batch_size,
+ 1,
+ config.num_attention_heads,
+ block_len,
+ 3 * block_len + global_seq_length,
+ )
+ self.assertIsInstance(attentions, tuple)
+ self.assertListEqual(
+ [layer_attentions.shape for layer_attentions in attentions],
+ [encoder_expected_shape] * len(attentions),
+ )
+
+
+class LongT5EncoderOnlyModelTester:
+ def __init__(
+ self,
+ parent,
+ vocab_size=99,
+ batch_size=13,
+ encoder_seq_length=7,
+ local_radius=5,
+ encoder_attention_type="local",
+ global_block_size=3,
+ # For common tests
+ use_attention_mask=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ d_ff=37,
+ relative_attention_num_buckets=8,
+ is_training=False,
+ dropout_rate=0.1,
+ initializer_factor=0.002,
+ is_encoder_decoder=False,
+ eos_token_id=1,
+ pad_token_id=0,
+ scope=None,
+ large_model_config_path="google/long-t5-local-large",
+ ):
+
+ self.parent = parent
+ self.batch_size = batch_size
+ self.encoder_seq_length = encoder_seq_length
+ self.local_radius = local_radius
+ self.block_len = local_radius + 1
+ self.encoder_attention_type = encoder_attention_type
+ self.global_block_size = global_block_size
+ # For common tests
+ self.seq_length = self.encoder_seq_length
+ self.use_attention_mask = use_attention_mask
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.d_ff = d_ff
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.dropout_rate = dropout_rate
+ self.initializer_factor = initializer_factor
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.is_encoder_decoder = is_encoder_decoder
+ self.scope = None
+ self.is_training = is_training
+ self.large_model_config_path = large_model_config_path
+
+ def get_large_model_config(self):
+ return LongT5Config.from_pretrained(self.large_model_config_path)
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+
+ attention_mask = None
+ if self.use_attention_mask:
+ attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+
+ config = LongT5Config(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ d_ff=self.d_ff,
+ d_kv=self.hidden_size // self.num_attention_heads,
+ num_layers=self.num_hidden_layers,
+ num_heads=self.num_attention_heads,
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
+ dropout_rate=self.dropout_rate,
+ initializer_factor=self.initializer_factor,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.pad_token_id,
+ pad_token_id=self.pad_token_id,
+ is_encoder_decoder=self.is_encoder_decoder,
+ local_radius=self.local_radius,
+ encoder_attention_type=self.encoder_attention_type,
+ global_block_size=self.global_block_size,
+ )
+
+ return (
+ config,
+ input_ids,
+ attention_mask,
+ )
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ attention_mask,
+ ):
+ model = LongT5EncoderModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )
+ result = model(input_ids=input_ids)
+ encoder_output = result.last_hidden_state
+
+ self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ attention_mask,
+ ) = config_and_inputs
+
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+class LongT5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (LongT5EncoderModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_torchscript = True
+ test_resize_embeddings = False
+ test_model_parallel = False
+
+ def setUp(self):
+ self.model_tester = LongT5EncoderOnlyModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ block_len = getattr(self.model_tester, "block_len", 4)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len],
+ )
+
+
+class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest):
+ def setUp(self):
+ self.model_tester = LongT5EncoderOnlyModelTester(
+ self, encoder_attention_type="transient-global", large_model_config_path="google/long-t5-tglobal-large"
+ )
+ self.config_tester = ConfigTester(self, config_class=LongT5Config, d_model=37)
+
+ def test_attention_outputs(self):
+ if not self.has_attentions:
+ pass
+
+ else:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ block_len = getattr(self.model_tester, "block_len", None)
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ global_block_size = getattr(self.model_tester, "global_block_size", 4)
+ global_seq_len = seq_len // global_block_size
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
+ )
+
+
+def use_task_specific_params(model, task):
+ model.config.update(model.config.task_specific_params[task])
+
+
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class LongT5ModelIntegrationTests(unittest.TestCase):
+ @cached_property
+ def model(self):
+ return LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps").to(
+ torch_device
+ )
+
+ @cached_property
+ def tokenizer(self):
+ return AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
+
+ def expected_summary(self):
+ return [
+ "background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in"
+ " developing world . it provides an excellent resolution for visualization of the coronaryarteries for"
+ " catheter - based or operating interventions . although the association of this technique with major"
+ " complications such as mortality is highly uncommon , it is frequently associated with various cardiac"
+ " and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the"
+ " diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for"
+ " major noncoron ary cardiac surgery referred"
+ ]
+
+ @slow
+ def test_summarization(self):
+ model = self.model
+ tok = self.tokenizer
+
+ ARTICLE = """coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . \n it provides an excellent resolution for visualization of the coronary arteries for catheter - based or operating interventions . \n
+ although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications . computed tomography ( ct ) coronary angiography is
+ a promising technique for the evaluation of cad noninvasively . \n it assesses disease within the coronary artery and provides qualitative and quantitative information about nonobstructive atherosclerotic plaque burden within the vessel
+ wall . \n thus , ct angiography - based disease evaluation may provide clinically more significant information than conventional angiography . the introduction of multi - slice computed tomography ( msct ) technology such as 64-slice , 12
+ 8-slice , 256-slice , and now 320-slice msct has produced a high diagnostic accuracy of ct coronary angiography . \n it has consistently showed to have a very high negative predictive value ( well above 90% ) in ruling out patients with s
+ ignificant cad defined as coronary luminal stenosis of > 50% . \n the american college of cardiology / american heart association recommends that coronary angiography should be performed before valve surgery in men aged > 40 years , women
+ aged > 35 years with coronary risk factors and in postmenopausal women . \n the prevalence of cad in patients undergoing valve replacement is 2040% in developed countries . in the previous studies , \n the incidence of angiographically p
+ roven cad in acquired valvular diseases has been shown to vary widely from 9% to 41% . in aortic stenosis , \n we aimed to report the diagnostic performance of 128-slice ct coronary angiography in 50 patients undergoing for major noncoron
+ ary cardiac surgery referred for diagnostic invasive coronary angiography to assess the extent and severity of coronary stenosis . \n during january 2013 to december 2014 , we enrolled fifty major noncoronary cardiac surgery patients sche
+ duled for invasive coronary angiography who fulfilled the following inclusion criteria of age 40 years , having low or intermediate probability of cad , left ventricular ejection fraction ( lvef ) > 35% , and patient giving informed conse
+ nt for undergoing msct and conventional coronary angiography . \n those having any contraindication for contrast injection , lvef < 35% , high pretest probability of cad , and hemodynamic instability were excluded from the study . \n pati
+ ents with heart rates of > 70 bpm received ( unless they had known overt heart failure or electrocardiogram ( ecg ) atrioventricular conduction abnormalities ) a single oral dose of 100 mg metoprolol 45 min before the scan . \n patients w
+ ith heart rates of > 80 bpm received an additional oral dose of metoprolol if not contraindicated . \n all patients were scanned with a 128-slice ct scanner ( siemens , somatom definition as ) equipped with a new feature in msct technolog
+ y , so - called z - axis flying - focus technology . \n the central 32 detector rows acquire 0.6-mm slices , and the flying - focus spot switches back and forth between 2 z positions between each reading . \n two slices per detector row a
+ re acquired , which results in a higher oversampling rate in the z - axis , thereby reducing artifacts related to the spiral acquisition and improving spatial resolution down to 0.4 mm . \n a bolus of 6580 ml contrast material ( omnipaque
+ ) was injected through an arm vein at a flow rate of 5 ml / s . \n a bolus tracking technique was used to synchronize the arrival of contrast in the coronary arteries with the initiation of the scan . to monitor the arrival of contrast m
+ aterial , \n axial scans were obtained at the level of the ascending aorta with a delay of 10 s after the start of the contrast injection . \n the scan was automatically started when a threshold of 150 hounsfield units was reached in a re
+ gion of interest positioned in the ascending aorta . \n images were reconstructed with ecg gating to obtain optimal , motion - free image quality . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a s
+ ingle observer unaware of the multi - slice ct results identified coronary lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiograp
+ hy . \n lesions were classified as having nonsignificant disease ( luminal irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean
+ lumen diameter reduction was 50% using a validated quantitative coronary angiography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiograp
+ hy . \n total calcium scores of all patients were calculated with dedicated software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of th
+ e number , areas , and peak hounsfield units of the detected calcified lesions . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were
+ used to identify coronary lesions and ( curved ) multiplanar reconstructions to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the di
+ agnostic performance of ct coronary angiography for the detection of significant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and
+ positive and negative likelihood ratios with the corresponding exact 95% of confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease p
+ er vessel ) , and patient by patient ( no or any disease per patient ) . \n all scans were performed within 2 weeks of the msct coronary diagnostic angiogram . a single observer unaware of the multi - slice ct results identified coronary
+ lesion as a single vessel , double vessel , or triple vessel disease . \n all lesion , regardless of size , were included for comparison with ct coronary angiography . \n lesions were classified as having nonsignificant disease ( luminal
+ irregularities or < 50% stenosis ) or as having significant stenosis . \n stenosis was evaluated in two orthogonal views and classified as significant if the mean lumen diameter reduction was 50% using a validated quantitative coronary an
+ giography ( qca ) . \n all scans were analyzed independently by a radiologist and a cardiologist who were unaware of the results of conventional coronary angiography . \n total calcium scores of all patients were calculated with dedicated
+ software and expressed as agatston scores . \n the agatston score is a commonly used scoring method that calculates the total amount of calcium on the basis of the number , areas , and peak hounsfield units of the detected calcified lesi
+ ons . \n all available coronary segments were visually scored for the presence of > 50% considered as significant stenosis . \n maximum intensity projections were used to identify coronary lesions and ( curved ) multiplanar reconstruction
+ s to classify lesions as significant or nonsignificant . \n data were analyzed using statistical system spss version 20 software ( chicago , il , usa ) . \n the diagnostic performance of ct coronary angiography for the detection of signif
+ icant lesions in coronary arteries with qca as the standard of reference is presented as sensitivity , specificity , positive and negative predictive values , and positive and negative likelihood ratios with the corresponding exact 95% of
+ confidence interval ( cis ) . \n comparison between ct and conventional coronary angiography was performed on the two level vessel by vessel ( no or any disease per vessel ) , and patient by patient ( no or any disease per patient ) . \n
+ in this study , 29 ( 58% ) subjects were female , and 21 ( 42% ) were male showing an average age of 50.36 8.39 years . \n of fifty patients 24 ( 48% ) , 13 ( 26% ) , eight ( 16% ) , and five ( 10% ) underwent mitral valve replacement ,
+ double valve replacement ( dvr ) , aortic valve replacement , and other surgeries , respectively . \n high distribution of cad risk factors such as hypertension ( 24% ) , smoking ( 22% ) , and dyslipidemia ( 18% ) was observed in the stu
+ dy group . \n the mean creatinine level was 0.766 0.17 and average dye used in conventional angiography was 48.5 26.6 whereas for ct angiography it was 72.8 6.32 . \n average radiation dose in conventional coronary angiography and msct
+ coronary angiography was 5.2 msv and 9.2 msv , respectively . \n the majority of the patients had sinus rhythm ( 68% ) , whereas atrial fibrillation was found in 32% of the subjects . \n patients included in the study had low to intermed
+ iate probability of cad . in this study , three patients had complications after conventional angiography . \n complications were of local site hematoma , acute kidney injury managed conservatively , and acute heart failure . \n a patient
+ who developed hematoma was obese female patients with body mass index > 30 kg / m . \n the patient suffered from pseudoaneurysm , had hospitalized for 9 days , which leads to increased morbidity and cost of hospital stay . \n the diagnos
+ tic accuracy of ct coronary angiography was evaluated regarding true positive , true negative values and is presented in table 1 . the overall sensitivity and \n specificity of ct angiography technique was 100% ( 95% ci : 39.76%100% ) and
+ 91.30% ( 95% ci : 79.21%97.58% ) , respectively [ table 2 ] . \n the positive predictive value ( 50% ; 95% ci : 15.70%84.30% ) and negative predictive value ( 100% ; 95% ci : 91.59%100% ) of ct angiography were also fairly high in these
+ patients . \n recent reports from multiple studies demonstrated that recent - generation msct scanners showed promise for noninvasive detection of coronary stenosis however , until now no studies were found regarding the clinical efficacy
+ or prognostic value of 128-slice ct coronary angiography versus conventional invasive coronary angiography in the diagnosis of patients planned for major noncoronary surgeries such as dvr , bentall , atrial septal defect closure , etc .
+ in our study , we reported 8% cad prevalence in patients planned for major noncoronary cardiac surgery . \n we performed conventional and msct coronary angiography in all patients and the results showed that ct coronary angiography with i
+ nvasive coronary angiography as the reference standard had a considerably high sensitivity ( 100% ) and specificity ( 95.65% ) . \n the health economic model using invasive coronary angiography as the reference standard showed that at a p
+ retest probability of cad of 70% or lower , ct coronary angiography resulted in lower cost per patient with a true positive diagnosis . at a pretest probability of cad of 70% or higher , invasive coronary angiography was associated with a
+ lower cost per patient with a true positive diagnosis . in our study population , \n two patients developed local site complications in the form of hematoma and pseudoaneurysm after conventional angiography . \n hence , msct coronary ang
+ iography will be more favorable in female obese patients with intermediate likelihood of cad . \n hence , msct coronary angiography will be cost - effective in patients of valvular heart diseases . \n however , ct angiography suffers from
+ a drawback that average amount of dye used in msct coronary angiography were 72.8 6.32 ml which is higher than average amount of dye required for conventional angiography ( 48.6 26.6 ml ) . \n hence , the use of ct coronary angiography
+ could not be used in patients with known renal dysfunction , where reduction of contrast dye load is highly advocated . \n our results show that 128-slice ct coronary angiography is a reliable technique to detect coronary stenosis in pat
+ ients planned for noncoronary cardiac surgery . \n although there has been important technological progress in the development of ct coronary angiography , its clinical application remains limited . \n a study wth large numbers of patient
+ s is required for the recommendation of only ct coronary angiography for the coronary evaluation in major non - cardiac surgeries . \n mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , guja
+ rat , india ) . \n u.n . mehta institute of cardiology and research center ( affiliated to bj medical college , ahmedabad , gujarat , india ) . \n """
+
+ dct = tok(
+ [ARTICLE],
+ max_length=1024,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ ).to(torch_device)
+
+ hypotheses_batch = model.generate(
+ **dct,
+ num_beams=4,
+ length_penalty=2.0,
+ max_length=142,
+ min_length=56,
+ no_repeat_ngram_size=3,
+ do_sample=False,
+ early_stopping=True,
+ )
+
+ decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ self.assertListEqual(
+ self.expected_summary(),
+ decoded,
+ )
+
+ @slow
+ def test_inference_hidden_states(self):
+ model = self.model
+
+ input_ids = torch.tensor(
+ [[100, 19, 3, 9, 7142, 1200, 145, 8, 1252, 14145, 2034, 812, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+ dtype=torch.long,
+ device=torch_device,
+ )
+ decoder_input_ids = torch.tensor(
+ [[100, 19, 3, 9, 7142, 1200, 145, 8, 1252, 14145, 2034, 812, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+ dtype=torch.long,
+ device=torch_device,
+ )
+ attention_mask = torch.tensor(
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+ dtype=torch.long,
+ device=torch_device,
+ )
+
+ output = model(
+ input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, output_hidden_states=True
+ )
+
+ # check if encoder_outputs match
+ expected_output_slice = torch.tensor([0.0629, -0.1294, -0.0089, 0.0772, 0.0663], device=torch_device)
+ self.assertTrue(torch.allclose(output.encoder_hidden_states[-1][0, 0, :5], expected_output_slice, atol=1e-4))
+
+ # check if logits match
+ expected_output_slice = torch.tensor([5.5231, 6.1058, 3.1766, 8.2391, -5.9453], device=torch_device)
+ self.assertTrue(torch.allclose(output.logits[0, 0, :5], expected_output_slice, atol=1e-4))
diff --git a/tests/mluke/__init__.py b/tests/models/luke/__init__.py
similarity index 100%
rename from tests/mluke/__init__.py
rename to tests/models/luke/__init__.py
diff --git a/tests/luke/test_modeling_luke.py b/tests/models/luke/test_modeling_luke.py
similarity index 95%
rename from tests/luke/test_modeling_luke.py
rename to tests/models/luke/test_modeling_luke.py
index 99a34cc81c0a23..264b7f89559d3e 100644
--- a/tests/luke/test_modeling_luke.py
+++ b/tests/models/luke/test_modeling_luke.py
@@ -18,8 +18,8 @@
from transformers import LukeConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
@@ -270,9 +270,12 @@ def create_and_check_for_masked_lm(
entity_labels=entity_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
- self.parent.assertEqual(
- result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
- )
+ if entity_ids is not None:
+ self.parent.assertEqual(
+ result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
+ )
+ else:
+ self.parent.assertIsNone(result.entity_logits)
def create_and_check_for_entity_classification(
self,
@@ -488,6 +491,11 @@ def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+ def test_for_masked_lm_with_word_only(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
+ self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+
def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
@@ -624,7 +632,10 @@ def test_inference_base_model(self):
model.to(torch_device)
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification")
- text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ text = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped"
+ " the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
span = (39, 42)
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
@@ -656,7 +667,10 @@ def test_inference_large_model(self):
model.to(torch_device)
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large", task="entity_classification")
- text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ text = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped"
+ " the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
span = (39, 42)
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
diff --git a/tests/luke/test_tokenization_luke.py b/tests/models/luke/test_tokenization_luke.py
similarity index 96%
rename from tests/luke/test_tokenization_luke.py
rename to tests/models/luke/test_tokenization_luke.py
index 456246384c0556..aa208f950bf3e2 100644
--- a/tests/luke/test_tokenization_luke.py
+++ b/tests/models/luke/test_tokenization_luke.py
@@ -13,20 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from typing import Tuple
from transformers import AddedToken, LukeTokenizer
-from transformers.testing_utils import require_torch, slow
+from transformers.testing_utils import get_tests_dir, require_torch, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/vocab.json")
-SAMPLE_MERGE_FILE = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/merges.txt")
-SAMPLE_ENTITY_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_entity_vocab.json")
+SAMPLE_VOCAB = get_tests_dir("fixtures/vocab.json")
+SAMPLE_MERGE_FILE = get_tests_dir("fixtures/merges.txt")
+SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")
class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
@@ -482,7 +480,10 @@ def test_text_pair_padding_pytorch_tensors(self):
def test_entity_classification_no_padding_or_truncation(self):
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification")
- sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ sentence = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped"
+ " the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
span = (39, 42)
encoding = tokenizer(sentence, entity_spans=[span], return_token_type_ids=True)
@@ -493,7 +494,8 @@ def test_entity_classification_no_padding_or_truncation(self):
self.assertEqual(len(encoding["token_type_ids"]), 42)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon.",
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous"
+ " netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon.",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][9:12], spaces_between_special_tokens=False), " she"
@@ -516,7 +518,10 @@ def test_entity_classification_padding_pytorch_tensors(self):
tokenizer = LukeTokenizer.from_pretrained(
"studio-ousia/luke-base", task="entity_classification", return_token_type_ids=True
)
- sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ sentence = (
+ "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped"
+ " the new world number one avoid a humiliating second- round exit at Wimbledon ."
+ )
# entity information
span = (39, 42)
diff --git a/tests/mobilebert/__init__.py b/tests/models/lxmert/__init__.py
similarity index 100%
rename from tests/mobilebert/__init__.py
rename to tests/models/lxmert/__init__.py
diff --git a/tests/lxmert/test_modeling_lxmert.py b/tests/models/lxmert/test_modeling_lxmert.py
similarity index 99%
rename from tests/lxmert/test_modeling_lxmert.py
rename to tests/models/lxmert/test_modeling_lxmert.py
index f1209d132dc6c4..1c51d02e96b714 100644
--- a/tests/lxmert/test_modeling_lxmert.py
+++ b/tests/models/lxmert/test_modeling_lxmert.py
@@ -23,8 +23,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -535,6 +535,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else ()
+ fx_compatible = True
test_head_masking = False
test_pruning = False
test_torchscript = False
diff --git a/tests/lxmert/test_modeling_tf_lxmert.py b/tests/models/lxmert/test_modeling_tf_lxmert.py
similarity index 99%
rename from tests/lxmert/test_modeling_tf_lxmert.py
rename to tests/models/lxmert/test_modeling_tf_lxmert.py
index 19226545a9b9f7..7594f889189c88 100644
--- a/tests/lxmert/test_modeling_tf_lxmert.py
+++ b/tests/models/lxmert/test_modeling_tf_lxmert.py
@@ -22,8 +22,8 @@
from transformers import LxmertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/lxmert/test_tokenization_lxmert.py b/tests/models/lxmert/test_tokenization_lxmert.py
similarity index 97%
rename from tests/lxmert/test_tokenization_lxmert.py
rename to tests/models/lxmert/test_tokenization_lxmert.py
index 38b76074c6c6a8..76047b1f44bccb 100644
--- a/tests/lxmert/test_tokenization_lxmert.py
+++ b/tests/models/lxmert/test_tokenization_lxmert.py
@@ -21,7 +21,7 @@
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/mpnet/__init__.py b/tests/models/m2m_100/__init__.py
similarity index 100%
rename from tests/mpnet/__init__.py
rename to tests/models/m2m_100/__init__.py
diff --git a/tests/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py
similarity index 95%
rename from tests/m2m_100/test_modeling_m2m_100.py
rename to tests/models/m2m_100/test_modeling_m2m_100.py
index 52a677618abd84..0d5bdc3ca3037d 100644
--- a/tests/m2m_100/test_modeling_m2m_100.py
+++ b/tests/models/m2m_100/test_modeling_m2m_100.py
@@ -23,9 +23,9 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -231,6 +231,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
)
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -354,7 +355,9 @@ def test_seq_to_seq_generation(self):
src_fr = [
"L'affaire NSA souligne l'absence totale de dƩbat sur le renseignement",
"Selon moi, il y a deux niveaux de rƩponse de la part du gouvernement franƧais.",
- "Lorsque FranƧois Hollande tĆ©lĆ©phone Ć Barack Obama ou quand le ministre des affaires Ć©trangĆØres Laurent Fabius convoque l'ambassadeur des Etats-Unis, ils rĆ©agissent Ć une vraie dĆ©couverte, qui est celle de l'ampleur de la surveillance amĆ©ricaine sur l'ensemble des communications en France.",
+ "Lorsque FranƧois Hollande tĆ©lĆ©phone Ć Barack Obama ou quand le ministre des affaires Ć©trangĆØres Laurent"
+ " Fabius convoque l'ambassadeur des Etats-Unis, ils rĆ©agissent Ć une vraie dĆ©couverte, qui est celle de"
+ " l'ampleur de la surveillance amƩricaine sur l'ensemble des communications en France.",
]
# The below article tests that we don't add any hypotheses outside of the top n_beams
@@ -370,7 +373,9 @@ def test_seq_to_seq_generation(self):
expected_en = [
"The NSA case highlights the total absence of intelligence debate",
"I think there are two levels of response from the French government.",
- "When FranƧois Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S. Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all communications in France.",
+ "When FranƧois Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
+ " Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
+ " communications in France.",
]
generated = tokenizer.batch_decode(
diff --git a/tests/m2m_100/test_tokenization_m2m_100.py b/tests/models/m2m_100/test_tokenization_m2m_100.py
similarity index 97%
rename from tests/m2m_100/test_tokenization_m2m_100.py
rename to tests/models/m2m_100/test_tokenization_m2m_100.py
index 35652f0cb3a1bd..729deb6cd4861b 100644
--- a/tests/m2m_100/test_tokenization_m2m_100.py
+++ b/tests/models/m2m_100/test_tokenization_m2m_100.py
@@ -12,26 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import tempfile
import unittest
-from os.path import dirname
from pathlib import Path
from shutil import copyfile
from transformers import M2M100Tokenizer, is_torch_available
-from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch, slow
+from transformers.testing_utils import (
+ get_tests_dir,
+ nested_simplify,
+ require_sentencepiece,
+ require_tokenizers,
+ require_torch,
+ slow,
+)
from transformers.utils import is_sentencepiece_available
if is_sentencepiece_available():
from transformers.models.m2m_100.tokenization_m2m_100 import save_json, VOCAB_FILES_NAMES
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
if is_sentencepiece_available():
- SAMPLE_SP = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+ SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available():
diff --git a/tests/mt5/__init__.py b/tests/models/marian/__init__.py
similarity index 100%
rename from tests/mt5/__init__.py
rename to tests/models/marian/__init__.py
diff --git a/tests/marian/test_modeling_flax_marian.py b/tests/models/marian/test_modeling_flax_marian.py
similarity index 99%
rename from tests/marian/test_modeling_flax_marian.py
rename to tests/models/marian/test_modeling_flax_marian.py
index bfb0d273add1d3..4180eb565cf598 100644
--- a/tests/marian/test_modeling_flax_marian.py
+++ b/tests/models/marian/test_modeling_flax_marian.py
@@ -21,8 +21,8 @@
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
diff --git a/tests/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py
similarity index 99%
rename from tests/marian/test_modeling_marian.py
rename to tests/models/marian/test_modeling_marian.py
index 067144d8cdc32b..1039c4a51d4293 100644
--- a/tests/marian/test_modeling_marian.py
+++ b/tests/models/marian/test_modeling_marian.py
@@ -22,9 +22,9 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -230,6 +230,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
diff --git a/tests/marian/test_modeling_tf_marian.py b/tests/models/marian/test_modeling_tf_marian.py
similarity index 99%
rename from tests/marian/test_modeling_tf_marian.py
rename to tests/models/marian/test_modeling_tf_marian.py
index eb4f24700ba16f..e62d7f0d35ccad 100644
--- a/tests/marian/test_modeling_tf_marian.py
+++ b/tests/models/marian/test_modeling_tf_marian.py
@@ -22,8 +22,8 @@
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
diff --git a/tests/marian/test_tokenization_marian.py b/tests/models/marian/test_tokenization_marian.py
similarity index 96%
rename from tests/marian/test_tokenization_marian.py
rename to tests/models/marian/test_tokenization_marian.py
index 6b6ee6c9662d0b..2cbc0b0a3fe7da 100644
--- a/tests/marian/test_tokenization_marian.py
+++ b/tests/models/marian/test_tokenization_marian.py
@@ -13,25 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import tempfile
import unittest
-from os.path import dirname
from pathlib import Path
from shutil import copyfile
from transformers import BatchEncoding, MarianTokenizer
-from transformers.testing_utils import require_sentencepiece, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, slow
from transformers.utils import is_sentencepiece_available, is_tf_available, is_torch_available
if is_sentencepiece_available():
from transformers.models.marian.tokenization_marian import VOCAB_FILES_NAMES, save_json
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_SP = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
zh_code = ">>zh<<"
diff --git a/tests/nystromformer/__init__.py b/tests/models/maskformer/__init__.py
similarity index 100%
rename from tests/nystromformer/__init__.py
rename to tests/models/maskformer/__init__.py
diff --git a/tests/maskformer/test_feature_extraction_maskformer.py b/tests/models/maskformer/test_feature_extraction_maskformer.py
similarity index 99%
rename from tests/maskformer/test_feature_extraction_maskformer.py
rename to tests/models/maskformer/test_feature_extraction_maskformer.py
index 259954643fc99b..461add8c035565 100644
--- a/tests/maskformer/test_feature_extraction_maskformer.py
+++ b/tests/models/maskformer/test_feature_extraction_maskformer.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py
similarity index 99%
rename from tests/maskformer/test_modeling_maskformer.py
rename to tests/models/maskformer/test_modeling_maskformer.py
index 43daf85ab1016b..bbc24719d753a5 100644
--- a/tests/maskformer/test_modeling_maskformer.py
+++ b/tests/models/maskformer/test_modeling_maskformer.py
@@ -24,8 +24,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin
if is_torch_available():
diff --git a/tests/openai/__init__.py b/tests/models/mbart/__init__.py
similarity index 100%
rename from tests/openai/__init__.py
rename to tests/models/mbart/__init__.py
diff --git a/tests/mbart/test_modeling_flax_mbart.py b/tests/models/mbart/test_modeling_flax_mbart.py
similarity index 99%
rename from tests/mbart/test_modeling_flax_mbart.py
rename to tests/models/mbart/test_modeling_flax_mbart.py
index bd235487d79855..1009dc95dd2a2f 100644
--- a/tests/mbart/test_modeling_flax_mbart.py
+++ b/tests/models/mbart/test_modeling_flax_mbart.py
@@ -21,8 +21,8 @@
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
diff --git a/tests/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py
similarity index 97%
rename from tests/mbart/test_modeling_mbart.py
rename to tests/models/mbart/test_modeling_mbart.py
index 2037ee79efb379..6a8eeed9fb41c3 100644
--- a/tests/mbart/test_modeling_mbart.py
+++ b/tests/models/mbart/test_modeling_mbart.py
@@ -23,9 +23,9 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -224,6 +224,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
)
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -348,7 +349,9 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
]
tgt_text = [
"Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
- 'Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£a Åi mizeria pentru milioane de oameni.',
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei"
+ ' pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor'
+ " face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£a Åi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, 250004]
diff --git a/tests/mbart/test_modeling_tf_mbart.py b/tests/models/mbart/test_modeling_tf_mbart.py
similarity index 99%
rename from tests/mbart/test_modeling_tf_mbart.py
rename to tests/models/mbart/test_modeling_tf_mbart.py
index eec59fe1b5e0c2..559a44e5db6ae4 100644
--- a/tests/mbart/test_modeling_tf_mbart.py
+++ b/tests/models/mbart/test_modeling_tf_mbart.py
@@ -20,8 +20,8 @@
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
diff --git a/tests/mbart/test_tokenization_mbart.py b/tests/models/mbart/test_tokenization_mbart.py
similarity index 95%
rename from tests/mbart/test_tokenization_mbart.py
rename to tests/models/mbart/test_tokenization_mbart.py
index 3b842b60fb125a..e80531051b65cb 100644
--- a/tests/mbart/test_tokenization_mbart.py
+++ b/tests/models/mbart/test_tokenization_mbart.py
@@ -12,18 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import shutil
import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
-from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
+from transformers.testing_utils import (
+ get_tests_dir,
+ nested_simplify,
+ require_sentencepiece,
+ require_tokenizers,
+ require_torch,
+)
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available():
@@ -208,7 +213,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
]
tgt_text = [
"Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
- 'Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.',
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei"
+ ' pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor'
+ " face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
diff --git a/tests/pegasus/__init__.py b/tests/models/mbart50/__init__.py
similarity index 100%
rename from tests/pegasus/__init__.py
rename to tests/models/mbart50/__init__.py
diff --git a/tests/mbart50/test_tokenization_mbart50.py b/tests/models/mbart50/test_tokenization_mbart50.py
similarity index 96%
rename from tests/mbart50/test_tokenization_mbart50.py
rename to tests/models/mbart50/test_tokenization_mbart50.py
index 3e39beb67be8cb..5a65d8856656dc 100644
--- a/tests/mbart50/test_tokenization_mbart50.py
+++ b/tests/models/mbart50/test_tokenization_mbart50.py
@@ -12,19 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import shutil
import tempfile
import unittest
-from os.path import dirname
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
-from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch, slow
+from transformers.testing_utils import (
+ get_tests_dir,
+ nested_simplify,
+ require_sentencepiece,
+ require_tokenizers,
+ require_torch,
+ slow,
+)
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available():
from transformers.models.mbart.modeling_mbart import shift_tokens_right
@@ -198,7 +203,9 @@ class MBart50OneToManyIntegrationTest(unittest.TestCase):
]
tgt_text = [
"Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
- 'Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.',
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei"
+ ' pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu vor'
+ " face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [EN_CODE, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
diff --git a/tests/perceiver/__init__.py b/tests/models/mctct/__init__.py
similarity index 100%
rename from tests/perceiver/__init__.py
rename to tests/models/mctct/__init__.py
diff --git a/tests/models/mctct/test_feature_extraction_mctct.py b/tests/models/mctct/test_feature_extraction_mctct.py
new file mode 100644
index 00000000000000..e0c77ad450fde6
--- /dev/null
+++ b/tests/models/mctct/test_feature_extraction_mctct.py
@@ -0,0 +1,274 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import random
+import unittest
+
+import numpy as np
+
+from transformers import is_speech_available
+from transformers.testing_utils import require_torch, require_torchaudio
+
+from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
+
+
+if is_speech_available():
+ from transformers import MCTCTFeatureExtractor
+
+global_rng = random.Random()
+
+
+def floats_list(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ values = []
+ for _batch_idx in range(shape[0]):
+ values.append([])
+ for _ in range(shape[1]):
+ values[-1].append(rng.random() * scale)
+
+ return values
+
+
+@require_torch
+@require_torchaudio
+class MCTCTFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ min_seq_length=400,
+ max_seq_length=2000,
+ feature_size=24,
+ num_mel_bins=24,
+ padding_value=0.0,
+ sampling_rate=16_000,
+ return_attention_mask=True,
+ do_normalize=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.min_seq_length = min_seq_length
+ self.max_seq_length = max_seq_length
+ self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
+ self.feature_size = feature_size
+ self.num_mel_bins = num_mel_bins
+ self.padding_value = padding_value
+ self.sampling_rate = sampling_rate
+ self.return_attention_mask = return_attention_mask
+ self.do_normalize = do_normalize
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "feature_size": self.feature_size,
+ "num_mel_bins": self.num_mel_bins,
+ "padding_value": self.padding_value,
+ "sampling_rate": self.sampling_rate,
+ "return_attention_mask": self.return_attention_mask,
+ "do_normalize": self.do_normalize,
+ }
+
+ def prepare_inputs_for_common(self, equal_length=False, numpify=False):
+ def _flatten(list_of_lists):
+ return list(itertools.chain(*list_of_lists))
+
+ if equal_length:
+ speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
+ else:
+ # make sure that inputs increase in size
+ speech_inputs = [
+ floats_list((x, self.feature_size))
+ for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
+ ]
+ if numpify:
+ speech_inputs = [np.asarray(x) for x in speech_inputs]
+ return speech_inputs
+
+
+@require_torch
+@require_torchaudio
+class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
+
+ feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None
+
+ def setUp(self):
+ self.feat_extract_tester = MCTCTFeatureExtractionTester(self)
+
+ def _check_zero_mean_unit_variance(self, input_vector):
+ self.assertTrue(np.all(np.mean(input_vector) < 1e-3))
+ self.assertTrue(np.all(np.abs(np.var(input_vector) - 1) < 1e-3))
+
+ def test_call(self):
+ # Tests that all call wrap to encode_plus and batch_encode_plus
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ # create three inputs of length 800, 1000, and 12000
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+ np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
+
+ # Test feature size
+ input_features = feature_extractor(np_speech_inputs, padding=True, return_tensors="np").input_features
+ self.assertTrue(input_features.ndim == 3)
+ self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size)
+
+ # Test not batched input
+ encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features
+ encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features
+ self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
+
+ # Test batched
+ encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
+ encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
+ for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+ self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+ def test_cepstral_mean_and_variance_normalization(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+
+ paddings = ["longest", "max_length", "do_not_pad"]
+ max_lengths = [None, 16, None]
+ for max_length, padding in zip(max_lengths, paddings):
+ inputs = feature_extractor(
+ speech_inputs,
+ padding=padding,
+ max_length=max_length,
+ return_attention_mask=True,
+ truncation=max_length is not None, # reference to #16419
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = [np.sum(x) for x in attention_mask]
+ self._check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]])
+ self._check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]])
+ self._check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]])
+
+ def test_cepstral_mean_and_variance_normalization_np(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+
+ paddings = ["longest", "max_length", "do_not_pad"]
+ max_lengths = [None, 16, None]
+ for max_length, padding in zip(max_lengths, paddings):
+ inputs = feature_extractor(
+ speech_inputs,
+ max_length=max_length,
+ padding=padding,
+ return_tensors="np",
+ return_attention_mask=True,
+ truncation=max_length is not None,
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = [np.sum(x) for x in attention_mask]
+
+ self._check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]])
+ self.assertTrue(input_features[0][fbank_feat_lengths[0] :].sum() < 1e-6)
+ self._check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]])
+ self.assertTrue(input_features[0][fbank_feat_lengths[1] :].sum() < 1e-6)
+ self._check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]])
+
+ def test_cepstral_mean_and_variance_normalization_trunc_max_length(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+ inputs = feature_extractor(
+ speech_inputs,
+ padding="max_length",
+ max_length=4,
+ truncation=True,
+ return_tensors="np",
+ return_attention_mask=True,
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
+
+ self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
+ self._check_zero_mean_unit_variance(input_features[1])
+ self._check_zero_mean_unit_variance(input_features[2])
+
+ def test_cepstral_mean_and_variance_normalization_trunc_longest(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+ inputs = feature_extractor(
+ speech_inputs,
+ padding="longest",
+ max_length=4,
+ truncation=True,
+ return_tensors="np",
+ return_attention_mask=True,
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
+
+ self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
+ self._check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
+ self._check_zero_mean_unit_variance(input_features[2])
+
+ # make sure that if max_length < longest -> then pad to max_length
+ self.assertEqual(input_features.shape, (3, 4, 24))
+
+ speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
+ inputs = feature_extractor(
+ speech_inputs,
+ padding="longest",
+ max_length=16,
+ truncation=True,
+ return_tensors="np",
+ return_attention_mask=True,
+ )
+ input_features = inputs.input_features
+ attention_mask = inputs.attention_mask
+ fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
+
+ self._check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
+ self._check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
+ self._check_zero_mean_unit_variance(input_features[2])
+
+ # make sure that if max_length < longest -> then pad to max_length
+ self.assertEqual(input_features.shape, (3, 16, 24))
+
+ def test_double_precision_pad(self):
+ import torch
+
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
+ py_speech_inputs = np_speech_inputs.tolist()
+
+ for inputs in [py_speech_inputs, np_speech_inputs]:
+ np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
+ self.assertTrue(np_processed.input_features.dtype == np.float32)
+ pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
+ self.assertTrue(pt_processed.input_features.dtype == torch.float32)
+
+ def test_different_window(self):
+ import torch
+
+ init_dict = self.feat_extract_tester.prepare_feat_extract_dict()
+ init_dict["win_function"] = "hann_window"
+
+ feature_extractor = self.feature_extraction_class(**init_dict)
+ np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
+ py_speech_inputs = np_speech_inputs.tolist()
+
+ for inputs in [py_speech_inputs, np_speech_inputs]:
+ np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
+ self.assertTrue(np_processed.input_features.dtype == np.float32)
+ pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
+ self.assertTrue(pt_processed.input_features.dtype == torch.float32)
diff --git a/tests/models/mctct/test_modeling_mctct.py b/tests/models/mctct/test_modeling_mctct.py
new file mode 100644
index 00000000000000..ee4a9efc2fef7d
--- /dev/null
+++ b/tests/models/mctct/test_modeling_mctct.py
@@ -0,0 +1,647 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch MCTCT model. """
+
+import inspect
+import math
+import unittest
+
+from datasets import load_dataset
+
+from transformers import MCTCTConfig, is_torch_available
+from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import MCTCTForCTC, MCTCTModel, MCTCTProcessor
+
+
+class MCTCTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=10,
+ seq_length=40, # speech is longer
+ is_training=False,
+ vocab_size=32,
+ hidden_size=128 * 4,
+ num_hidden_layers=4,
+ intermediate_size=20,
+ num_attention_heads=4,
+ attention_head_dim=128,
+ max_position_embeddings=920,
+ layer_norm_eps=1e-5,
+ layerdrop=0.3,
+ hidden_act="relu",
+ initializer_range=0.02,
+ hidden_dropout_prob=0.3,
+ attention_probs_dropout_prob=0.3,
+ conv_glu_dim=1,
+ conv_dropout=0.3,
+ num_conv_layers=1,
+ conv_kernel=(7,),
+ conv_stride=(3,),
+ input_feat_per_channel=80,
+ input_channels=1,
+ conv_channels=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length # speech is longer
+ self.is_training = is_training
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+
+ self.attention_head_dim = attention_head_dim
+ self.max_position_embeddings = max_position_embeddings
+
+ self.layer_norm_eps = layer_norm_eps
+ self.layerdrop = layerdrop
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+
+ self.conv_glu_dim = conv_glu_dim
+ self.conv_dropout = conv_dropout
+ self.num_conv_layers = num_conv_layers
+ self.conv_kernel = conv_kernel
+ self.conv_stride = conv_stride
+ self.input_feat_per_channel = input_feat_per_channel
+ self.input_channels = input_channels
+ self.conv_channels = conv_channels
+
+ output_seq_length = self.seq_length
+ dilation = 1
+ for _, kernel_sz, stride in zip(range(self.num_conv_layers), self.conv_kernel, self.conv_stride):
+ padding = kernel_sz // 2
+ output_seq_length = output_seq_length + 2 * padding - dilation * (kernel_sz - 1) - 1
+ output_seq_length = torch.div(output_seq_length, stride, rounding_mode="trunc") + 1
+
+ self.output_seq_length = int(math.ceil(output_seq_length))
+ self.encoder_seq_length = self.output_seq_length
+
+ def prepare_config_and_inputs(self):
+ input_features = floats_tensor(
+ [self.batch_size, self.seq_length, self.input_feat_per_channel], self.vocab_size
+ )
+ attention_mask = torch.ones([self.batch_size, self.seq_length], dtype=torch.long, device=torch_device)
+
+ config = self.get_config()
+
+ return config, input_features, attention_mask
+
+ def get_config(self):
+ return MCTCTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ intermediate_size=self.intermediate_size,
+ num_attention_heads=self.num_attention_heads,
+ attention_head_dim=self.attention_head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ layer_norm_eps=self.layer_norm_eps,
+ layerdrop=self.layerdrop,
+ hidden_act=self.hidden_act,
+ initializer_range=self.initializer_range,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ conv_glu_dim=self.conv_glu_dim,
+ conv_dropout=self.conv_dropout,
+ num_conv_layers=self.num_conv_layers,
+ conv_kernel=self.conv_kernel,
+ conv_stride=self.conv_stride,
+ input_feat_per_channel=self.input_feat_per_channel,
+ input_channels=self.input_channels,
+ conv_channels=self.conv_channels,
+ )
+
+ def create_and_check_model(self, config, input_features, attention_mask):
+ model = MCTCTModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_for_ctc(self, config, input_features, attention_mask):
+ config.add_adapter = True
+ config.output_hidden_size = 2 * config.hidden_size
+ model = MCTCTForCTC(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_features, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.adapter_output_seq_length, self.vocab_size)
+ )
+
+ def create_and_check_batch_inference(self, config, input_features, *args):
+ # test does not pass for models making use of `group_norm`
+ # check: https://github.com/pytorch/fairseq/issues/3227
+ model = MCTCTModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ input_features = input_features[:3]
+ attention_mask = torch.ones(input_features.shape[:-1], device=torch_device, dtype=torch.bool)
+
+ input_lengths = [input_features.shape[-1] // i for i in [2, 2, 1]]
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0.0
+
+ batch_outputs = model(input_features, attention_mask=attention_mask).last_hidden_state
+
+ for i in range(input_features.shape[0]):
+ input_slice = input_features[i : i + 1, : input_lengths[i]]
+ output = model(input_slice).last_hidden_state
+
+ batch_output = batch_outputs[i : i + 1, : output.shape[1]]
+ self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
+
+ def check_ctc_loss(self, config, input_features, *args):
+ model = MCTCTForCTC(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_features = input_features[:3]
+
+ # input_features is a 2D window for each sequence
+ attention_mask = torch.ones(input_features.shape[:-1], device=torch_device, dtype=torch.long)
+
+ # -2 since input_features is a 2D window for each sequence in batch
+ input_lengths = [input_features.shape[-2] // i for i in [2, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ model.config.ctc_loss_reduction = "sum"
+ sum_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
+
+ model.config.ctc_loss_reduction = "mean"
+ mean_loss = model(input_features, attention_mask=attention_mask, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(sum_loss, float))
+ self.parent.assertTrue(isinstance(mean_loss, float))
+
+ def check_ctc_training(self, config, input_features, *args):
+ config.ctc_zero_infinity = True
+ model = MCTCTForCTC(config=config)
+ model.to(torch_device)
+ model.train()
+
+ input_features = input_features[:3]
+
+ input_lengths = [input_features.shape[-2] // i for i in [2, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 1), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_features[i, input_lengths[i] :] = 0.0
+
+ if max_length_labels[i] < labels.shape[-1]:
+ # it's important that we make sure that target lenghts are at least
+ # one shorter than logit lenghts to prevent -inf
+ labels[i, max_length_labels[i] - 1 :] = -100
+
+ loss = model(input_features, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_labels_out_of_vocab(self, config, input_features, *args):
+ model = MCTCTForCTC(config)
+ model.to(torch_device)
+ model.train()
+
+ input_features = input_features[:3]
+
+ input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_features.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
+
+ with self.parent.assertRaises(ValueError):
+ model(input_features, labels=labels)
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_features, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {"input_features": input_features, "attention_mask": attention_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class MCTCTModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (MCTCTForCTC, MCTCTModel) if is_torch_available() else ()
+ test_pruning = False
+ test_headmasking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = MCTCTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=MCTCTConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_loss(*config_and_inputs)
+
+ def test_ctc_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_training(*config_and_inputs)
+
+ def test_labels_out_of_vocab(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
+
+ # MCTCT has no inputs_embeds
+ def test_inputs_embeds(self):
+ pass
+
+ # `input_ids` is renamed to `input_features`
+ def test_forward_signature(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = [
+ "input_features",
+ "attention_mask",
+ "head_mask",
+ "output_attentions",
+ "output_hidden_states",
+ "return_dict",
+ ]
+ self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
+
+ # MCTCT cannot resize token embeddings
+ # since it has no tokens embeddings
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ # MCTCT has no inputs_embeds
+ def test_model_common_attributes(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+ config.layerdrop = 0.0
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ input_features = inputs_dict["input_features"]
+
+ input_lengths = torch.tensor(
+ [input_features.shape[1] for _ in range(input_features.shape[0])], dtype=torch.long, device=torch_device
+ )
+ output_lengths = model._get_feat_extract_output_lengths(input_lengths)
+
+ labels = ids_tensor((input_features.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
+ inputs_dict["labels"] = labels
+
+ outputs = model(**inputs_dict)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = [
+ "conv.weight",
+ "masked_spec_embed",
+ "codevectors",
+ "quantizer.weight_proj.weight",
+ "project_hid.weight",
+ "project_hid.bias",
+ "project_q.weight",
+ "project_q.bias",
+ "feature_projection.projection.weight",
+ "feature_projection.projection.bias",
+ "objective.weight",
+ ]
+ if param.requires_grad:
+ if any([x in name for x in uniform_init_parms]):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # overwrite from test_modeling_common
+ def _mock_init_weights(self, module):
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data.fill_(3)
+ if hasattr(module, "weight_g") and module.weight_g is not None:
+ module.weight_g.data.fill_(3)
+ if hasattr(module, "weight_v") and module.weight_v is not None:
+ module.weight_v.data.fill_(3)
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.fill_(3)
+ if hasattr(module, "codevectors") and module.codevectors is not None:
+ module.codevectors.data.fill_(3)
+ if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+ module.masked_spec_embed.data.fill_(3)
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = MCTCTModel.from_pretrained("speechbrain/m-ctc-t-large")
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class MCTCTRobustModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (MCTCTForCTC, MCTCTModel) if is_torch_available() else ()
+ test_pruning = False
+ test_headmasking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = MCTCTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=MCTCTConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_batched_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_batch_inference(*config_and_inputs)
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_loss(*config_and_inputs)
+
+ def test_ctc_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_training(*config_and_inputs)
+
+ def test_labels_out_of_vocab(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
+
+ # MCTCT has no inputs_embeds
+ def test_inputs_embeds(self):
+ pass
+
+ # `input_ids` is renamed to `input_features`
+ def test_forward_signature(self):
+ pass
+
+ # MCTCT cannot resize token embeddings
+ # since it has no tokens embeddings
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ # MCTCT has no inputs_embeds
+ # and thus the `get_input_embeddings` fn
+ # is not implemented
+ def test_model_common_attributes(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ # set layer drop to 0
+ model.config.layerdrop = 0.0
+
+ input_features = inputs_dict["input_features"]
+
+ input_lengths = torch.tensor(
+ [input_features.shape[1] for _ in range(input_features.shape[0])], dtype=torch.long, device=torch_device
+ )
+ output_lengths = model._get_feat_extract_output_lengths(input_lengths)
+
+ labels = ids_tensor((input_features.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
+ inputs_dict["labels"] = labels
+
+ outputs = model(**inputs_dict)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = [
+ "conv.weight",
+ "masked_spec_embed",
+ "codevectors",
+ "quantizer.weight_proj.weight",
+ "project_hid.weight",
+ "project_hid.bias",
+ "project_q.weight",
+ "project_q.bias",
+ "feature_projection.projection.weight",
+ "feature_projection.projection.bias",
+ "objective.weight",
+ ]
+ if param.requires_grad:
+ if any([x in name for x in uniform_init_parms]):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # overwrite from test_modeling_common
+ def _mock_init_weights(self, module):
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data.fill_(3)
+ if hasattr(module, "weight_g") and module.weight_g is not None:
+ module.weight_g.data.fill_(3)
+ if hasattr(module, "weight_v") and module.weight_v is not None:
+ module.weight_v.data.fill_(3)
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.fill_(3)
+ if hasattr(module, "codevectors") and module.codevectors is not None:
+ module.codevectors.data.fill_(3)
+ if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+ module.masked_spec_embed.data.fill_(3)
+
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = MCTCTModel.from_pretrained("speechbrain/m-ctc-t-large")
+ self.assertIsNotNone(model)
+
+
+@require_torch
+@require_soundfile
+@slow
+class MCTCTModelIntegrationTest(unittest.TestCase):
+ def _load_datasamples(self, num_samples):
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ # automatic decoding with librispeech
+ speech_samples = ds.sort("id").filter(
+ lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)]
+ )[:num_samples]["audio"]
+
+ return [x["array"] for x in speech_samples]
+
+ def test_inference_ctc_normal(self):
+ model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")
+ model.to(torch_device)
+ processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)
+ input_speech = self._load_datasamples(1)
+
+ input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_features).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = ["a man said to the universe, sir, i exist."]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ def test_inference_ctc_normal_batched(self):
+ model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large")
+ model.to(torch_device)
+ processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True)
+
+ input_features = inputs.input_features.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_features, attention_mask=attention_mask).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe, sir, i exist.",
+ '"sweat-covered brion\'s body, trickling into the tight-lowing clossa was the only germent huor."',
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ def test_inference_ctc_robust_batched(self):
+ model = MCTCTForCTC.from_pretrained("speechbrain/m-ctc-t-large").to(torch_device)
+ processor = MCTCTProcessor.from_pretrained("speechbrain/m-ctc-t-large", do_lower_case=True)
+
+ input_speech = self._load_datasamples(4)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True)
+
+ input_features = inputs.input_features.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_features, attention_mask=attention_mask).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe, sir, i exist.",
+ '"sweat-covered brion\'s body, trickling into the tight-lowing clossa was the only germent huor." "',
+ "\"the cadona's chest still-dripping bloodthe acofis overstrained eyes, even the soring arena around him"
+ " with thousands of spectators retrivialities not worth-thinking about.",
+ "his instant panic was followed by a small sharp blow high on his chestr.",
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/mctct/test_processor_mctct.py b/tests/models/mctct/test_processor_mctct.py
new file mode 100644
index 00000000000000..83201f410215c0
--- /dev/null
+++ b/tests/models/mctct/test_processor_mctct.py
@@ -0,0 +1,147 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+from transformers import MCTCTProcessor, is_speech_available, is_torch_available
+from transformers.file_utils import FEATURE_EXTRACTOR_NAME
+from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2CTCTokenizer
+from transformers.testing_utils import require_torch, require_torchaudio
+
+
+if is_speech_available() and is_torch_available():
+ from transformers import MCTCTFeatureExtractor
+
+ from .test_feature_extraction_mctct import floats_list
+
+
+@require_torch
+@require_torchaudio
+class MCTCTProcessorTest(unittest.TestCase):
+ def setUp(self):
+ vocab = " | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
+ vocab_tokens = dict(zip(vocab, range(len(vocab))))
+
+ self.add_kwargs_tokens_map = {
+ "pad_token": "",
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ }
+ feature_extractor_map = {
+ "feature_size": 1,
+ "padding_value": 0.0,
+ "sampling_rate": 16000,
+ "return_attention_mask": False,
+ "do_normalize": True,
+ }
+
+ self.tmpdirname = tempfile.mkdtemp()
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(vocab_tokens) + "\n")
+
+ with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
+ fp.write(json.dumps(feature_extractor_map) + "\n")
+
+ def get_tokenizer(self, **kwargs_init):
+ kwargs = self.add_kwargs_tokens_map.copy()
+ kwargs.update(kwargs_init)
+ return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return MCTCTFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def test_save_load_pretrained_default(self):
+ tokenizer = self.get_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ processor.save_pretrained(self.tmpdirname)
+ processor = MCTCTProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+ self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, MCTCTFeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = MCTCTProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
+
+ processor = MCTCTProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, MCTCTFeatureExtractor)
+
+ def test_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ raw_speech = floats_list((3, 1000))
+
+ input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
+ input_processor = processor(raw_speech, return_tensors="np")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "This is a test string"
+
+ with processor.as_target_processor():
+ encoded_processor = processor(input_str)
+
+ encoded_tok = tokenizer(input_str)
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key], encoded_processor[key])
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)
diff --git a/tests/phobert/__init__.py b/tests/models/megatron_bert/__init__.py
similarity index 100%
rename from tests/phobert/__init__.py
rename to tests/models/megatron_bert/__init__.py
diff --git a/tests/megatron_bert/test_modeling_megatron_bert.py b/tests/models/megatron_bert/test_modeling_megatron_bert.py
similarity index 99%
rename from tests/megatron_bert/test_modeling_megatron_bert.py
rename to tests/models/megatron_bert/test_modeling_megatron_bert.py
index 01b93bf13a22fc..4ea3ddcb7be006 100644
--- a/tests/megatron_bert/test_modeling_megatron_bert.py
+++ b/tests/models/megatron_bert/test_modeling_megatron_bert.py
@@ -23,8 +23,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/plbart/__init__.py b/tests/models/megatron_gpt2/__init__.py
similarity index 100%
rename from tests/plbart/__init__.py
rename to tests/models/megatron_gpt2/__init__.py
diff --git a/tests/megatron_gpt2/test_modeling_megatron_gpt2.py b/tests/models/megatron_gpt2/test_modeling_megatron_gpt2.py
similarity index 100%
rename from tests/megatron_gpt2/test_modeling_megatron_gpt2.py
rename to tests/models/megatron_gpt2/test_modeling_megatron_gpt2.py
diff --git a/tests/poolformer/__init__.py b/tests/models/mluke/__init__.py
similarity index 100%
rename from tests/poolformer/__init__.py
rename to tests/models/mluke/__init__.py
diff --git a/tests/mluke/test_tokenization_mluke.py b/tests/models/mluke/test_tokenization_mluke.py
similarity index 98%
rename from tests/mluke/test_tokenization_mluke.py
rename to tests/models/mluke/test_tokenization_mluke.py
index c1d5ef639949b0..681825c7dccf9d 100644
--- a/tests/mluke/test_tokenization_mluke.py
+++ b/tests/models/mluke/test_tokenization_mluke.py
@@ -14,19 +14,17 @@
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from typing import Tuple
from transformers.models.mluke.tokenization_mluke import MLukeTokenizer
-from transformers.testing_utils import require_torch, slow
+from transformers.testing_utils import get_tests_dir, require_torch, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
-SAMPLE_ENTITY_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_entity_vocab.json")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
+SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")
class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
@@ -367,7 +365,8 @@ def test_text_pair_no_padding_or_truncation(self):
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- " ISO 639-3 uses the code fas for the dialects spoken across Iran and ć¢ćć¬ćć¹ćæć³ ( Afghanistan ).",
+ " ISO 639-3 uses the code fas for the dialects spoken across Iran and ć¢ćć¬ćć¹ćæć³ ( Afghanistan"
+ " ).",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
@@ -425,7 +424,8 @@ def test_text_pair_only_entity_spans_no_padding_or_truncation(self):
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- " ISO 639-3 uses the code fas for the dialects spoken across Iran and ć¢ćć¬ćć¹ćæć³ ( Afghanistan ).",
+ " ISO 639-3 uses the code fas for the dialects spoken across Iran and ć¢ćć¬ćć¹ćæć³ ( Afghanistan"
+ " ).",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
@@ -508,7 +508,8 @@ def test_entity_classification_no_padding_or_truncation(self):
self.assertEqual(len(encoding["token_type_ids"]), 23)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- " Japanese is anEast Asian languagespoken by about 128 million people, primarily in Japan.",
+ " Japanese is anEast Asian languagespoken by about 128 million people, primarily in"
+ " Japan.",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][4:9], spaces_between_special_tokens=False),
@@ -561,7 +562,8 @@ def test_entity_pair_classification_no_padding_or_truncation(self):
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
- "Japaneseis an East Asian language spoken by about 128 million people, primarily inJapan.",
+ "Japaneseis an East Asian language spoken by about 128 million people, primarily"
+ " inJapan.",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:4], spaces_between_special_tokens=False),
diff --git a/tests/prophetnet/__init__.py b/tests/models/mobilebert/__init__.py
similarity index 100%
rename from tests/prophetnet/__init__.py
rename to tests/models/mobilebert/__init__.py
diff --git a/tests/mobilebert/test_modeling_mobilebert.py b/tests/models/mobilebert/test_modeling_mobilebert.py
similarity index 99%
rename from tests/mobilebert/test_modeling_mobilebert.py
rename to tests/models/mobilebert/test_modeling_mobilebert.py
index 99e8e683235e6a..04301962c3cdad 100644
--- a/tests/mobilebert/test_modeling_mobilebert.py
+++ b/tests/models/mobilebert/test_modeling_mobilebert.py
@@ -20,8 +20,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/mobilebert/test_modeling_tf_mobilebert.py b/tests/models/mobilebert/test_modeling_tf_mobilebert.py
similarity index 99%
rename from tests/mobilebert/test_modeling_tf_mobilebert.py
rename to tests/models/mobilebert/test_modeling_tf_mobilebert.py
index c0ddf043562fa9..9db55cec2d58cc 100644
--- a/tests/mobilebert/test_modeling_tf_mobilebert.py
+++ b/tests/models/mobilebert/test_modeling_tf_mobilebert.py
@@ -19,8 +19,8 @@
from transformers import MobileBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/models/mobilebert/test_tokenization_mobilebert.py b/tests/models/mobilebert/test_tokenization_mobilebert.py
new file mode 100644
index 00000000000000..395f4a2aab2cb9
--- /dev/null
+++ b/tests/models/mobilebert/test_tokenization_mobilebert.py
@@ -0,0 +1,345 @@
+# coding=utf-8
+# Copyright 2022 Leon Derczynski. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the MobileBERT tokenizer. """
+
+
+import os
+import unittest
+
+from transformers import MobileBertTokenizer, MobileBertTokenizerFast
+from transformers.models.bert.tokenization_bert import (
+ VOCAB_FILES_NAMES,
+ BasicTokenizer,
+ WordpieceTokenizer,
+ _is_control,
+ _is_punctuation,
+ _is_whitespace,
+)
+from transformers.testing_utils import require_tokenizers, slow
+
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
+
+
+# Copied from transformers.tests.models.bert.test_modeling_bert.py with Bert->MobileBert and pathfix
+@require_tokenizers
+class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ tokenizer_class = MobileBertTokenizer
+ rust_tokenizer_class = MobileBertTokenizerFast
+ test_rust_tokenizer = True
+ space_between_special_tokens = True
+ from_pretrained_filter = filter_non_english
+ pre_trained_model_path = "google/mobilebert-uncased"
+
+ def setUp(self):
+ super().setUp()
+
+ vocab_tokens = [
+ "[UNK]",
+ "[CLS]",
+ "[SEP]",
+ "[PAD]",
+ "[MASK]",
+ "want",
+ "##want",
+ "##ed",
+ "wa",
+ "un",
+ "runn",
+ "##ing",
+ ",",
+ "low",
+ "lowest",
+ ]
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+ self.tokenizers_list = [
+ (tokenizer_def[0], self.pre_trained_model_path, tokenizer_def[2]) # else the 'google/' prefix is stripped
+ for tokenizer_def in self.tokenizers_list
+ ]
+
+ def get_input_output_texts(self, tokenizer):
+ input_text = "UNwant\u00E9d,running"
+ output_text = "unwanted, running"
+ return input_text, output_text
+
+ def test_full_tokenizer(self):
+ tokenizer = self.tokenizer_class(self.vocab_file)
+
+ tokens = tokenizer.tokenize("UNwant\u00E9d,running")
+ self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+ self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "UNwant\u00E9d,running"
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ # With lower casing
+ tokenizer = self.get_tokenizer(do_lower_case=True)
+ rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)
+
+ sequence = "UNwant\u00E9d,running"
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ def test_chinese(self):
+ tokenizer = BasicTokenizer()
+
+ self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
+
+ def test_basic_tokenizer_lower(self):
+ tokenizer = BasicTokenizer(do_lower_case=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_lower_strip_accents_false(self):
+ tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hƤllo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
+
+ def test_basic_tokenizer_lower_strip_accents_true(self):
+ tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_lower_strip_accents_default(self):
+ tokenizer = BasicTokenizer(do_lower_case=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_no_lower(self):
+ tokenizer = BasicTokenizer(do_lower_case=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_no_lower_strip_accents_false(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["HƤLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_no_lower_strip_accents_true(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_respects_never_split_tokens(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
+ )
+
+ def test_wordpiece_tokenizer(self):
+ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
+
+ vocab = {}
+ for i, token in enumerate(vocab_tokens):
+ vocab[token] = i
+ tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
+
+ self.assertListEqual(tokenizer.tokenize(""), [])
+
+ self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"])
+
+ self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
+
+ def test_is_whitespace(self):
+ self.assertTrue(_is_whitespace(" "))
+ self.assertTrue(_is_whitespace("\t"))
+ self.assertTrue(_is_whitespace("\r"))
+ self.assertTrue(_is_whitespace("\n"))
+ self.assertTrue(_is_whitespace("\u00A0"))
+
+ self.assertFalse(_is_whitespace("A"))
+ self.assertFalse(_is_whitespace("-"))
+
+ def test_is_control(self):
+ self.assertTrue(_is_control("\u0005"))
+
+ self.assertFalse(_is_control("A"))
+ self.assertFalse(_is_control(" "))
+ self.assertFalse(_is_control("\t"))
+ self.assertFalse(_is_control("\r"))
+
+ def test_is_punctuation(self):
+ self.assertTrue(_is_punctuation("-"))
+ self.assertTrue(_is_punctuation("$"))
+ self.assertTrue(_is_punctuation("`"))
+ self.assertTrue(_is_punctuation("."))
+
+ self.assertFalse(_is_punctuation("A"))
+ self.assertFalse(_is_punctuation(" "))
+
+ def test_clean_text(self):
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340
+ self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]])
+
+ self.assertListEqual(
+ [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]
+ )
+
+ @slow
+ def test_sequence_builders(self):
+ tokenizer = self.tokenizer_class.from_pretrained("google/mobilebert-uncased")
+
+ text = tokenizer.encode("sequence builders", add_special_tokens=False)
+ text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
+
+ encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
+ encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
+
+ assert encoded_sentence == [101] + text + [102]
+ assert encoded_pair == [101] + text + [102] + text_2 + [102]
+
+ def test_offsets_with_special_characters(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ sentence = f"A, naĆÆve {tokenizer_r.mask_token} AllenNLP sentence."
+ tokens = tokenizer_r.encode_plus(
+ sentence,
+ return_attention_mask=False,
+ return_token_type_ids=False,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+
+ do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False
+ expected_results = (
+ [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "A"),
+ ((1, 2), ","),
+ ((3, 5), "na"),
+ ((5, 6), "##ĆÆ"),
+ ((6, 8), "##ve"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "Allen"),
+ ((21, 23), "##NL"),
+ ((23, 24), "##P"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ if not do_lower_case
+ else [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "a"),
+ ((1, 2), ","),
+ ((3, 8), "naive"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "allen"),
+ ((21, 23), "##nl"),
+ ((23, 24), "##p"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ )
+
+ self.assertEqual(
+ [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
+ )
+ self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+
+ def test_change_tokenize_chinese_chars(self):
+ list_of_commun_chinese_char = ["ē", "äŗŗ", "ę"]
+ text_with_chinese_char = "".join(list_of_commun_chinese_char)
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+ kwargs["tokenize_chinese_chars"] = True
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+ ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+
+ tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+ tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+ # it is expected that each Chinese character is not preceded by "##"
+ self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char)
+ self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char)
+
+ kwargs["tokenize_chinese_chars"] = False
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+ ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+
+ tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+ tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+ # it is expected that only the first Chinese character is not preceded by "##".
+ expected_tokens = [
+ f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
+ ]
+ self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
+ self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
diff --git a/tests/qdqbert/__init__.py b/tests/models/mpnet/__init__.py
similarity index 100%
rename from tests/qdqbert/__init__.py
rename to tests/models/mpnet/__init__.py
diff --git a/tests/mpnet/test_modeling_mpnet.py b/tests/models/mpnet/test_modeling_mpnet.py
similarity index 98%
rename from tests/mpnet/test_modeling_mpnet.py
rename to tests/models/mpnet/test_modeling_mpnet.py
index 5417313998bfb3..1e72870fdaddf1 100644
--- a/tests/mpnet/test_modeling_mpnet.py
+++ b/tests/models/mpnet/test_modeling_mpnet.py
@@ -19,8 +19,8 @@
from transformers import MPNetConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/mpnet/test_modeling_tf_mpnet.py b/tests/models/mpnet/test_modeling_tf_mpnet.py
similarity index 98%
rename from tests/mpnet/test_modeling_tf_mpnet.py
rename to tests/models/mpnet/test_modeling_tf_mpnet.py
index f9f9e2d51201b6..a0a4964d57e95a 100644
--- a/tests/mpnet/test_modeling_tf_mpnet.py
+++ b/tests/models/mpnet/test_modeling_tf_mpnet.py
@@ -19,8 +19,8 @@
from transformers import MPNetConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/mpnet/test_tokenization_mpnet.py b/tests/models/mpnet/test_tokenization_mpnet.py
similarity index 97%
rename from tests/mpnet/test_tokenization_mpnet.py
rename to tests/models/mpnet/test_tokenization_mpnet.py
index 4cc677397d24bf..f761b0280953de 100644
--- a/tests/mpnet/test_tokenization_mpnet.py
+++ b/tests/models/mpnet/test_tokenization_mpnet.py
@@ -21,7 +21,7 @@
from transformers.models.mpnet.tokenization_mpnet import VOCAB_FILES_NAMES, MPNetTokenizer
from transformers.testing_utils import require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/rag/__init__.py b/tests/models/mt5/__init__.py
similarity index 100%
rename from tests/rag/__init__.py
rename to tests/models/mt5/__init__.py
diff --git a/tests/mt5/test_modeling_flax_mt5.py b/tests/models/mt5/test_modeling_flax_mt5.py
similarity index 100%
rename from tests/mt5/test_modeling_flax_mt5.py
rename to tests/models/mt5/test_modeling_flax_mt5.py
diff --git a/tests/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py
similarity index 100%
rename from tests/mt5/test_modeling_mt5.py
rename to tests/models/mt5/test_modeling_mt5.py
diff --git a/tests/mt5/test_modeling_tf_mt5.py b/tests/models/mt5/test_modeling_tf_mt5.py
similarity index 100%
rename from tests/mt5/test_modeling_tf_mt5.py
rename to tests/models/mt5/test_modeling_tf_mt5.py
diff --git a/tests/realm/__init__.py b/tests/models/nystromformer/__init__.py
similarity index 100%
rename from tests/realm/__init__.py
rename to tests/models/nystromformer/__init__.py
diff --git a/tests/nystromformer/test_modeling_nystromformer.py b/tests/models/nystromformer/test_modeling_nystromformer.py
similarity index 98%
rename from tests/nystromformer/test_modeling_nystromformer.py
rename to tests/models/nystromformer/test_modeling_nystromformer.py
index e3e962b3109351..b93c074bf68377 100644
--- a/tests/nystromformer/test_modeling_nystromformer.py
+++ b/tests/models/nystromformer/test_modeling_nystromformer.py
@@ -20,8 +20,8 @@
from transformers import AutoTokenizer, NystromformerConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/reformer/__init__.py b/tests/models/openai/__init__.py
similarity index 100%
rename from tests/reformer/__init__.py
rename to tests/models/openai/__init__.py
diff --git a/tests/openai/test_modeling_openai.py b/tests/models/openai/test_modeling_openai.py
similarity index 98%
rename from tests/openai/test_modeling_openai.py
rename to tests/models/openai/test_modeling_openai.py
index 80babf5b517f9f..2ff935eef590b4 100644
--- a/tests/openai/test_modeling_openai.py
+++ b/tests/models/openai/test_modeling_openai.py
@@ -19,9 +19,9 @@
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
diff --git a/tests/openai/test_modeling_tf_openai.py b/tests/models/openai/test_modeling_tf_openai.py
similarity index 98%
rename from tests/openai/test_modeling_tf_openai.py
rename to tests/models/openai/test_modeling_tf_openai.py
index f74a85ee60d62c..7cdc2a8bb1879b 100644
--- a/tests/openai/test_modeling_tf_openai.py
+++ b/tests/models/openai/test_modeling_tf_openai.py
@@ -19,8 +19,8 @@
from transformers import OpenAIGPTConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/openai/test_tokenization_openai.py b/tests/models/openai/test_tokenization_openai.py
similarity index 98%
rename from tests/openai/test_tokenization_openai.py
rename to tests/models/openai/test_tokenization_openai.py
index a9ac22c9283d48..26030632918b7a 100644
--- a/tests/openai/test_tokenization_openai.py
+++ b/tests/models/openai/test_tokenization_openai.py
@@ -22,7 +22,7 @@
from transformers.models.openai.tokenization_openai import VOCAB_FILES_NAMES
from transformers.testing_utils import require_ftfy, require_spacy, require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/regnet/__init__.py b/tests/models/opt/__init__.py
similarity index 100%
rename from tests/regnet/__init__.py
rename to tests/models/opt/__init__.py
diff --git a/tests/models/opt/test_modeling_flax_opt.py b/tests/models/opt/test_modeling_flax_opt.py
new file mode 100644
index 00000000000000..17dce9eace2dbf
--- /dev/null
+++ b/tests/models/opt/test_modeling_flax_opt.py
@@ -0,0 +1,406 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import timeout_decorator # noqa
+
+from transformers import OPTConfig, is_flax_available
+from transformers.testing_utils import require_flax, require_sentencepiece, slow
+
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+
+
+if is_flax_available():
+ import os
+
+ # The slow tests are often failing with OOM error on GPU
+ # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
+ # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
+ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+ import jax
+ import jax.numpy as jnp
+ from transformers import FlaxOPTForCausalLM, FlaxOPTModel, GPT2Tokenizer
+
+
+def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None):
+ if attention_mask is None:
+ attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+
+
+@require_flax
+class FlaxOPTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ embed_dim=16,
+ word_embed_proj_dim=16,
+ initializer_range=0.02,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.embed_dim = embed_dim
+ self.word_embed_proj_dim = word_embed_proj_dim
+ self.initializer_range = initializer_range
+ self.is_encoder_decoder = False
+
+ def prepare_config_and_inputs(self):
+ input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
+ input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)
+
+ config = OPTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ embed_dim=self.embed_dim,
+ is_encoder_decoder=False,
+ word_embed_proj_dim=self.word_embed_proj_dim,
+ initializer_range=self.initializer_range,
+ use_cache=False,
+ )
+ inputs_dict = prepare_opt_inputs_dict(config, input_ids)
+ return config, inputs_dict
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+ def check_use_cache_forward(self, model_class_name, config, inputs_dict):
+ max_length = 20
+ model = model_class_name(config)
+
+ input_ids = inputs_dict["input_ids"]
+ attention_mask = inputs_dict["attention_mask"]
+
+ past_key_values = model.init_cache(input_ids.shape[0], max_length)
+ attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4")
+
+ position_ids = jnp.broadcast_to(
+ jnp.arange(input_ids.shape[-1] - 1)[None, :],
+ (input_ids.shape[0], input_ids.shape[-1] - 1),
+ )
+ outputs_cache = model(
+ input_ids[:, :-1],
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
+ outputs_cache_next = model(
+ input_ids[:, -1:],
+ attention_mask=attention_mask,
+ past_key_values=outputs_cache.past_key_values,
+ position_ids=position_ids,
+ )
+
+ outputs = model(input_ids)
+
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
+
+ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
+ max_length = 20
+ model = model_class_name(config)
+
+ input_ids, attention_mask = (
+ inputs_dict["input_ids"],
+ inputs_dict["attention_mask"],
+ )
+
+ attention_mask_cache = jnp.concatenate(
+ [
+ attention_mask,
+ jnp.zeros((attention_mask.shape[0], max_length - attention_mask.shape[1])),
+ ],
+ axis=-1,
+ )
+
+ past_key_values = model.init_cache(input_ids.shape[0], max_length)
+ position_ids = jnp.broadcast_to(
+ jnp.arange(input_ids.shape[-1] - 1)[None, :],
+ (input_ids.shape[0], input_ids.shape[-1] - 1),
+ )
+
+ outputs_cache = model(
+ input_ids[:, :-1],
+ attention_mask=attention_mask_cache,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+ position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
+ outputs_cache_next = model(
+ input_ids[:, -1:],
+ past_key_values=outputs_cache.past_key_values,
+ attention_mask=attention_mask_cache,
+ position_ids=position_ids,
+ )
+
+ outputs = model(input_ids, attention_mask=attention_mask)
+
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
+
+
+@require_flax
+class FlaxOPTModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
+ all_model_classes = (FlaxOPTModel, FlaxOPTForCausalLM) if is_flax_available() else ()
+ all_generative_model_classes = () if is_flax_available() else ()
+
+ def setUp(self):
+ self.model_tester = FlaxOPTModelTester(self)
+
+ def test_use_cache_forward(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
+
+ def test_use_cache_forward_with_attn_mask(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_class_name in self.all_model_classes:
+ model = model_class_name.from_pretrained("facebook/opt-125m")
+ input_ids = np.ones((1, 1)) * model.config.eos_token_id
+ outputs = model(input_ids)
+ self.assertIsNotNone(outputs)
+
+
+@require_sentencepiece
+@require_flax
+class FlaxOPTModelIntegrationTests(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ model = FlaxOPTModel.from_pretrained("facebook/opt-350m")
+ input_ids = jnp.array([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ output = model(input_ids=input_ids).last_hidden_state
+ expected_shape = (1, 11, 512)
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = jnp.array(
+ [[-0.2867, -1.9256, -0.3062], [-1.2711, -0.1337, -0.1897], [0.4109, 0.1187, -1.3142]]
+ )
+ self.assertTrue(jnp.allclose(output[:, :3, :3], expected_slice, atol=4e-2))
+
+
+@require_flax
+@slow
+class FlaxOPTEmbeddingsTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.path_model = "facebook/opt-350m"
+
+ def test_logits(self):
+ model = FlaxOPTForCausalLM.from_pretrained(self.path_model)
+ tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
+
+ prompts = [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+ # verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
+ inputs = tokenizer(prompts, return_tensors="jax", padding=True, add_special_tokens=False)
+ logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(axis=-1)
+ logits_meta = jnp.array(
+ [
+ [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
+ [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
+ [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
+ [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
+ ]
+ )
+ self.assertTrue(jnp.allclose(logits, logits_meta, atol=4e-2))
+
+ model = jax.jit(model)
+ logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(axis=-1)
+ self.assertTrue(jnp.allclose(logits, logits_meta, atol=4e-2))
+
+
+@require_flax
+@slow
+class FlaxOPTGenerationTest(unittest.TestCase):
+ @property
+ def prompts(self):
+ return [
+ "Today is a beautiful day and I want",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+
+ def test_generation_pre_attn_layer_norm(self):
+ model_id = "facebook/opt-125m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want everyone",
+ "In the city of Rome Canaver Canaver Canaver Canaver",
+ "Paris is the capital of France and Parisdylib",
+ "Computers and mobile phones have taken precedence over",
+ ]
+
+ predicted_outputs = []
+
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="jax").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+ generated_ids = generated_ids[0]
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_generation_post_attn_layer_norm(self):
+ model_id = "facebook/opt-350m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of San Francisco, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="jax").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+ generated_ids = generated_ids[0]
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_jitted_batch_generation(self):
+ model_id = "facebook/opt-125m"
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to thank",
+ "In the city of Rome Canaver Canaver Canaver Canaver",
+ ]
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ inputs = tokenizer(
+ [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ ],
+ return_tensors="jax",
+ padding=True,
+ )
+
+ jit_generate = jax.jit(model.generate)
+
+ output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
+
+ output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
+
+ self.assertIsNotNone(output_string, EXPECTED_OUTPUTS)
+
+ # TODO fix in the following PR
+ # def test_batch_generation(self):
+ # model_id = "facebook/opt-350m"
+
+ # tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ # model = FlaxOPTForCausalLM.from_pretrained(model_id)
+
+ # tokenizer.padding_side = "left"
+
+ # # use different length sentences to test batching
+ # sentences = [
+ # "Hello, my dog is a little",
+ # "Today, I",
+ # ]
+
+ # inputs = tokenizer(sentences, return_tensors="jax", padding=True)
+ # input_ids = inputs["input_ids"]
+
+ # outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"], trace=False)
+
+ # inputs_non_padded = tokenizer(sentences[0], return_tensors="jax").input_ids
+ # output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ # num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].sum()
+ # inputs_padded = tokenizer(sentences[1], return_tensors="jax").input_ids
+ # output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ # batch_out_sentence = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
+ # non_padded_sentence = tokenizer.decode(output_non_padded[0][0], skip_special_tokens=True)
+ # padded_sentence = tokenizer.decode(output_padded[0][0], skip_special_tokens=True)
+
+ # expected_output_sentence = [
+ # "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+ # "Today, I"
+ # # TODO fix this test in next PR
+ # # "Today, I was in the middle of a conversation with a friend about the",
+ # ]
+ # self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ # # TODO outputs will be similar, fix in next PR
+ # self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
new file mode 100644
index 00000000000000..8018d05f090667
--- /dev/null
+++ b/tests/models/opt/test_modeling_opt.py
@@ -0,0 +1,430 @@
+# coding=utf-8
+# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch OPT model. """
+
+
+import copy
+import tempfile
+import unittest
+
+import timeout_decorator # noqa
+
+from transformers import OPTConfig, is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import GPT2Tokenizer, OPTForCausalLM, OPTModel
+
+
+def prepare_opt_inputs_dict(
+ config,
+ input_ids,
+ decoder_input_ids=None,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+):
+ if attention_mask is None:
+ attention_mask = input_ids.ne(config.pad_token_id)
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+ }
+
+
+class OPTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ embed_dim=16,
+ word_embed_proj_dim=16,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.embed_dim = embed_dim
+ self.word_embed_proj_dim = word_embed_proj_dim
+ self.is_encoder_decoder = False
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
+ 3,
+ )
+ input_ids[:, -1] = self.eos_token_id # Eos Token
+
+ decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ config = self.get_config()
+ inputs_dict = prepare_opt_inputs_dict(config, input_ids, decoder_input_ids)
+ return config, inputs_dict
+
+ def get_config(self):
+ return OPTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ embed_dim=self.embed_dim,
+ is_encoder_decoder=False,
+ word_embed_proj_dim=self.word_embed_proj_dim,
+ )
+
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.max_position_embeddings = 100
+ return config
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
+ model = OPTModel(config=config).to(torch_device).eval()
+
+ input_ids = inputs_dict["input_ids"]
+ attention_mask = inputs_dict["attention_mask"]
+ head_mask = inputs_dict["head_mask"]
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical multiple next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_attn_mask = ids_tensor((self.batch_size, 3), 2)
+
+ # append to next input_ids and
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
+ "last_hidden_state"
+ ]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+
+@require_torch
+class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
+ all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
+ is_encoder_decoder = False
+ fx_compatible = True
+ test_pruning = False
+ test_missing_keys = False
+
+ def setUp(self):
+ self.model_tester = OPTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=OPTConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_save_load_strict(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
+ self.assertEqual(info["missing_keys"], [])
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_inputs_embeds(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in (OPTModel,):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+ if not self.is_encoder_decoder:
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+ else:
+ encoder_input_ids = inputs["input_ids"]
+ decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+ del inputs["input_ids"]
+ inputs.pop("decoder_input_ids", None)
+
+ wte = model.get_input_embeddings()
+ if not self.is_encoder_decoder:
+ inputs["inputs_embeds"] = wte(input_ids)
+ else:
+ inputs["inputs_embeds"] = wte(encoder_input_ids)
+ inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+ with torch.no_grad():
+ model(**inputs)[0]
+
+ def test_generate_fp16(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs()
+ input_ids = input_dict["input_ids"]
+ attention_mask = input_ids.ne(1).to(torch_device)
+ model = OPTForCausalLM(config).eval().to(torch_device)
+ if torch_device == "cuda":
+ model.half()
+ model.generate(input_ids, attention_mask=attention_mask)
+ model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+
+def assert_tensors_close(a, b, atol=1e-12, prefix=""):
+ """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
+ if a is None and b is None:
+ return True
+ try:
+ if torch.allclose(a, b, atol=atol):
+ return True
+ raise
+ except Exception:
+ pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
+ if a.numel() > 100:
+ msg = f"tensor values are {pct_different:.1%} percent different."
+ else:
+ msg = f"{a} != {b}"
+ if prefix:
+ msg = prefix + ": " + msg
+ raise AssertionError(msg)
+
+
+def _long_tensor(tok_lst):
+ return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
+
+
+@require_torch
+class OPTModelIntegrationTests(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ model = OPTModel.from_pretrained("facebook/opt-350m").to(torch_device)
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+
+ with torch.no_grad():
+ output = model(input_ids=input_ids).last_hidden_state
+
+ expected_shape = torch.Size((1, 11, 512))
+ self.assertEqual(output.shape, expected_shape)
+ # expected value works for CPU, as well as GPU (with TF32 disabled)
+ expected_slice = torch.tensor(
+ [
+ [-0.28726277, -1.9241608, -0.3058734],
+ [-1.2737825, -0.13332152, -0.18766522],
+ [0.41159445, 0.1191957, -1.3107123],
+ ],
+ device=torch_device,
+ )
+ assert_tensors_close(output[0, :3, :3], expected_slice, atol=5e-5)
+
+
+@require_torch
+@slow
+class OPTEmbeddingsTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.path_model = "facebook/opt-350m"
+
+ def test_load_model(self):
+ try:
+ _ = OPTForCausalLM.from_pretrained(self.path_model)
+ except BaseException:
+ self.fail("Failed loading model")
+
+ def test_logits(self):
+ model = OPTForCausalLM.from_pretrained(self.path_model)
+ model = model.eval()
+ tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
+
+ prompts = [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+ # verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False)
+ logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(dim=-1)
+ # logits_meta = torch.load(self.path_logits_meta)
+ logits_meta = torch.Tensor(
+ [
+ [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
+ [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
+ [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
+ [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
+ ]
+ )
+ assert torch.allclose(logits, logits_meta, atol=1e-4)
+
+
+@slow
+class OPTGenerationTest(unittest.TestCase):
+ @property
+ def prompts(self):
+ return [
+ "Today is a beautiful day and I want",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+
+ def test_generation_pre_attn_layer_norm(self):
+ model_id = "facebook/opt-125m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want everyone",
+ "In the city of Rome Canaver Canaver Canaver Canaver",
+ "Paris is the capital of France and Parisdylib",
+ "Computers and mobile phones have taken precedence over",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = OPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_batch_generation(self):
+ model_id = "facebook/opt-350m"
+
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = OPTForCausalLM.from_pretrained(model_id)
+ model.to(torch_device)
+
+ tokenizer.padding_side = "left"
+
+ # use different length sentences to test batching
+ sentences = [
+ "Hello, my dog is a little",
+ "Today, I",
+ ]
+
+ inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+ input_ids = inputs["input_ids"].to(torch_device)
+
+ outputs = model.generate(
+ input_ids=input_ids,
+ attention_mask=inputs["attention_mask"].to(torch_device),
+ )
+
+ inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+ output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+ inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+ output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+ padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+ expected_output_sentence = [
+ "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+ "Today, I was in the middle of a conversation with a friend about the",
+ ]
+ self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
+
+ def test_generation_post_attn_layer_norm(self):
+ model_id = "facebook/opt-350m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of San Francisco, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = OPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py
new file mode 100644
index 00000000000000..d34d4f0fc8e6a2
--- /dev/null
+++ b/tests/models/opt/test_modeling_tf_opt.py
@@ -0,0 +1,414 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers import OPTConfig, is_tf_available
+from transformers.testing_utils import require_sentencepiece, require_tf, slow
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import GPT2Tokenizer, TFOPTForCausalLM, TFOPTModel
+
+
+def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None):
+ if attention_mask is None:
+ attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@require_tf
+class TFOPTModelTester:
+ config_cls = OPTConfig
+ config_updates = {}
+ hidden_act = "gelu"
+
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ embed_dim=16,
+ word_embed_proj_dim=16,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.embed_dim = embed_dim
+ self.word_embed_proj_dim = word_embed_proj_dim
+ self.is_encoder_decoder = False
+
+ def prepare_config_and_inputs_for_common(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
+ eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
+ input_ids = tf.concat([input_ids, eos_tensor], axis=1)
+
+ config = self.config_cls(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ embed_dim=self.embed_dim,
+ word_embed_proj_dim=self.word_embed_proj_dim,
+ is_encoder_decoder=False,
+ **self.config_updates,
+ )
+ inputs_dict = prepare_opt_inputs_dict(config, input_ids)
+ return config, inputs_dict
+
+ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
+ model = TFOPTModel(config=config)
+ input_ids = inputs_dict["input_ids"]
+
+ input_ids = input_ids[:1, :]
+ attention_mask = inputs_dict["attention_mask"][:1, :]
+ self.batch_size = 1
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+ # create hypothetical next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)
+
+ # append to next input_ids and
+ next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
+ next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
+
+ self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
+
+ # select random slice
+ random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
+ output_from_past_slice = output_from_past[:, :, random_slice_idx]
+
+ # test that outputs are equal for slice
+ tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
+
+
+@require_tf
+class TFOPTModelTest(TFModelTesterMixin, unittest.TestCase):
+ all_model_classes = (TFOPTModel, TFOPTForCausalLM) if is_tf_available() else ()
+ all_generative_model_classes = (TFOPTForCausalLM,) if is_tf_available() else ()
+ is_encoder_decoder = False
+ test_pruning = False
+ test_onnx = False
+ onnx_min_opset = 10
+
+ def setUp(self):
+ self.model_tester = TFOPTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=OPTConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_decoder_model_past_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_model_common_attributes(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+ if model_class in self.all_generative_model_classes:
+ x = model.get_output_embeddings()
+ assert isinstance(x, tf.keras.layers.Layer)
+ else:
+ x = model.get_output_embeddings()
+ assert x is None
+
+ def test_resize_token_embeddings(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def _get_word_embedding_weight(model, embedding_layer):
+ if hasattr(embedding_layer, "weight"):
+ return embedding_layer.weight
+ else:
+ # Here we build the word embeddings weights if not exists.
+ # And then we retry to get the attribute once built.
+ model(model.dummy_inputs)
+ if hasattr(embedding_layer, "weight"):
+ return embedding_layer.weight
+ else:
+ return None
+
+ for model_class in self.all_model_classes:
+ for size in [config.vocab_size - 10, config.vocab_size + 10]:
+ # build the embeddings
+ model = model_class(config=config)
+ old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
+ old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
+
+ # reshape the embeddings
+ model.resize_token_embeddings(size)
+ new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
+ new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
+
+ # check that the resized embeddings size matches the desired size.
+ assert_size = size if size is not None else config.vocab_size
+
+ self.assertEqual(new_input_embeddings.shape[0], assert_size)
+
+ # check that weights remain the same after resizing
+ models_equal = True
+ for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ if old_output_embeddings is not None and new_output_embeddings is not None:
+ self.assertEqual(new_output_embeddings.shape[0], assert_size)
+
+ models_equal = True
+ for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ def test_saved_model_creation(self):
+ # This test is too long (>30sec) and makes fail the CI
+ pass
+
+
+def _long_tensor(tok_lst):
+ return tf.constant(tok_lst, dtype=tf.int32)
+
+
+@require_tf
+class TFOPTHeadTests(unittest.TestCase):
+ vocab_size = 99
+
+ def _get_config_and_data(self):
+ eos_column_vector = tf.ones((4, 1), dtype=tf.int32) * 2
+ input_ids = tf.concat([ids_tensor((4, 6), self.vocab_size - 3) + 3, eos_column_vector], axis=1)
+ batch_size = input_ids.shape[0]
+ config = OPTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=24,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ ffn_dim=32,
+ max_position_embeddings=48,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ )
+ return config, input_ids, batch_size
+
+
+@require_sentencepiece
+@require_tf
+class OPTModelIntegrationTests(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ model = TFOPTModel.from_pretrained("facebook/opt-350m")
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ attention_mask = tf.not_equal(input_ids, model.config.pad_token_id)
+ with tf.GradientTape():
+ output = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+ expected_shape = (1, 11, 512)
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = tf.constant(
+ [[-0.2873, -1.9218, -0.3033], [-1.2710, -0.1338, -0.1902], [0.4095, 0.1214, -1.3121]]
+ )
+ self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-3))
+
+ xla_generate = tf.function(model, jit_compile=True)
+ output = xla_generate(input_ids, attention_mask)[0]
+ self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-2))
+
+
+@require_tf
+@slow
+class TFOPTEmbeddingsTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.path_model = "facebook/opt-350m"
+
+ def test_logits(self):
+ model = TFOPTForCausalLM.from_pretrained(self.path_model)
+ tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
+
+ prompts = [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+ # verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
+ inputs = tokenizer(prompts, return_tensors="tf", padding=True, add_special_tokens=False)
+ logits = tf.math.reduce_mean(model(inputs.input_ids, attention_mask=inputs.attention_mask)[0], axis=-1)
+ logits_meta = tf.constant(
+ [
+ [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
+ [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
+ [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
+ [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
+ ]
+ )
+ self.assertTrue(np.allclose(logits, logits_meta, atol=1e-4))
+
+ xla_generate = tf.function(model, jit_compile=True)
+ logits = tf.math.reduce_mean(xla_generate(inputs.input_ids, attention_mask=inputs.attention_mask)[0], axis=-1)
+ self.assertTrue(np.allclose(logits, logits_meta, atol=1e-4))
+
+
+@slow
+class TFOPTGenerationTest(unittest.TestCase):
+ @property
+ def prompts(self):
+ return [
+ "Today is a beautiful day and I want",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+
+ def test_generation_pre_attn_layer_norm(self):
+ model_id = "facebook/opt-125m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want everyone",
+ "In the city of Rome Canaver Canaver Canaver Canaver",
+ "Paris is the capital of France and Parisdylib",
+ "Computers and mobile phones have taken precedence over",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="tf").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_batch_generation(self):
+ model_id = "facebook/opt-350m"
+
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ tokenizer.padding_side = "left"
+
+ # use different length sentences to test batching
+ sentences = [
+ "Hello, my dog is a little",
+ "Today, I",
+ ]
+
+ inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+ input_ids = inputs["input_ids"]
+
+ outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
+
+ inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
+ output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ num_paddings = inputs_non_padded.shape[-1] - tf.math.reduce_sum(
+ tf.cast(inputs["attention_mask"][-1], tf.int64)
+ )
+ inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
+ output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+ padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+ expected_output_sentence = [
+ "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+ "Today, I was in the middle of a conversation with a friend about the",
+ ]
+ self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
+
+ def test_generation_post_attn_layer_norm(self):
+ model_id = "facebook/opt-350m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of San Francisco, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="tf").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
diff --git a/tests/rembert/__init__.py b/tests/models/pegasus/__init__.py
similarity index 100%
rename from tests/rembert/__init__.py
rename to tests/models/pegasus/__init__.py
diff --git a/tests/pegasus/test_modeling_flax_pegasus.py b/tests/models/pegasus/test_modeling_flax_pegasus.py
similarity index 99%
rename from tests/pegasus/test_modeling_flax_pegasus.py
rename to tests/models/pegasus/test_modeling_flax_pegasus.py
index 8f5c010477eb58..61c356bfb0ced3 100644
--- a/tests/pegasus/test_modeling_flax_pegasus.py
+++ b/tests/models/pegasus/test_modeling_flax_pegasus.py
@@ -19,8 +19,8 @@
from transformers import PegasusConfig, PegasusTokenizer, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
diff --git a/tests/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py
similarity index 99%
rename from tests/pegasus/test_modeling_pegasus.py
rename to tests/models/pegasus/test_modeling_pegasus.py
index 090300de0c9b2f..d5e9d22df189ce 100644
--- a/tests/pegasus/test_modeling_pegasus.py
+++ b/tests/models/pegasus/test_modeling_pegasus.py
@@ -21,10 +21,10 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ..mbart.test_modeling_mbart import AbstractSeq2SeqIntegrationTest
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_resize_position_embeddings = True
test_pruning = False
test_missing_keys = False
diff --git a/tests/pegasus/test_modeling_tf_pegasus.py b/tests/models/pegasus/test_modeling_tf_pegasus.py
similarity index 98%
rename from tests/pegasus/test_modeling_tf_pegasus.py
rename to tests/models/pegasus/test_modeling_tf_pegasus.py
index dd2b8c7d6194fa..14fcce39a649ca 100644
--- a/tests/pegasus/test_modeling_tf_pegasus.py
+++ b/tests/models/pegasus/test_modeling_tf_pegasus.py
@@ -20,8 +20,8 @@
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
@@ -339,7 +339,8 @@ class TFPegasusIntegrationTests(unittest.TestCase):
""" The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """,
]
expected_text = [
- "California's largest electricity provider has cut power to hundreds of thousands of customers in an effort to reduce the risk of wildfires.",
+ "California's largest electricity provider has cut power to hundreds of thousands of customers in an effort to"
+ " reduce the risk of wildfires.",
'N-Dubz have revealed they\'re "grateful" to have been nominated for four Mobo Awards.',
] # differs slightly from pytorch, likely due to numerical differences in linear layers
model_name = "google/pegasus-xsum"
diff --git a/tests/pegasus/test_tokenization_pegasus.py b/tests/models/pegasus/test_tokenization_pegasus.py
similarity index 94%
rename from tests/pegasus/test_tokenization_pegasus.py
rename to tests/models/pegasus/test_tokenization_pegasus.py
index 7634902b584e4f..d473725f9ae926 100644
--- a/tests/pegasus/test_tokenization_pegasus.py
+++ b/tests/models/pegasus/test_tokenization_pegasus.py
@@ -18,7 +18,7 @@
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, require_torch, slow
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
@@ -72,7 +72,10 @@ def test_vocab_size(self):
def test_mask_tokens_rust_pegasus(self):
rust_tokenizer = self.rust_tokenizer_class.from_pretrained(self.tmpdirname)
py_tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
- raw_input_str = "Let's see which is the better one It seems like this was important "
+ raw_input_str = (
+ "Let's see which is the better one It seems like this was important"
+ " "
+ )
rust_ids = rust_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0]
py_ids = py_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0]
self.assertListEqual(py_ids, rust_ids)
@@ -158,7 +161,10 @@ def get_input_output_texts(self, tokenizer):
def test_mask_tokens_rust_pegasus(self):
rust_tokenizer = self.rust_tokenizer_class.from_pretrained(self.tmpdirname)
py_tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
- raw_input_str = "Let's see which is the better one [MASK] It seems like this [MASK] was important "
+ raw_input_str = (
+ "Let's see which is the better one [MASK] It seems like this [MASK] was important "
+ " "
+ )
rust_ids = rust_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0]
py_ids = py_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0]
self.assertListEqual(py_ids, rust_ids)
@@ -198,7 +204,10 @@ def test_equivalence_to_orig_tokenizer(self):
tokenizer.tokenize(test_str)
"""
- test_str = "This is an example string that is used to test the original TF implementation against the HF implementation"
+ test_str = (
+ "This is an example string that is used to test the original TF implementation against the HF"
+ " implementation"
+ )
token_ids = self._large_tokenizer(test_str).input_ids
diff --git a/tests/resnet/__init__.py b/tests/models/perceiver/__init__.py
similarity index 100%
rename from tests/resnet/__init__.py
rename to tests/models/perceiver/__init__.py
diff --git a/tests/perceiver/test_modeling_perceiver.py b/tests/models/perceiver/test_modeling_perceiver.py
similarity index 98%
rename from tests/perceiver/test_modeling_perceiver.py
rename to tests/models/perceiver/test_modeling_perceiver.py
index a394b00852c1f5..5947a73a0e41bf 100644
--- a/tests/perceiver/test_modeling_perceiver.py
+++ b/tests/models/perceiver/test_modeling_perceiver.py
@@ -30,8 +30,8 @@
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -143,7 +143,7 @@ def prepare_config_and_inputs(self, model_class=None):
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
if model_class is None or model_class.__name__ == "PerceiverModel":
- inputs = floats_tensor([self.batch_size, self.seq_length, config.d_model], self.vocab_size)
+ inputs = floats_tensor([self.batch_size, self.seq_length, config.d_model], scale=1.0)
return config, inputs, input_mask, sequence_labels, token_labels
elif model_class.__name__ in ["PerceiverForMaskedLM", "PerceiverForSequenceClassification"]:
inputs = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -542,9 +542,12 @@ def recursive_check(tuple_object, dict_object):
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
- msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. "
- f"Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. "
- f"Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
)
recursive_check(tuple_output, dict_output)
@@ -767,7 +770,10 @@ def test_problem_types(self):
@require_torch_multi_gpu
@unittest.skip(
- reason="Perceiver does not work with data parallel (DP) because of a bug in PyTorch: https://github.com/pytorch/pytorch/issues/36035"
+ reason=(
+ "Perceiver does not work with data parallel (DP) because of a bug in PyTorch:"
+ " https://github.com/pytorch/pytorch/issues/36035"
+ )
)
def test_multi_gpu_data_parallel_forward(self):
pass
diff --git a/tests/perceiver/test_tokenization_perceiver.py b/tests/models/perceiver/test_tokenization_perceiver.py
similarity index 99%
rename from tests/perceiver/test_tokenization_perceiver.py
rename to tests/models/perceiver/test_tokenization_perceiver.py
index 0b6b7d4c75a8b2..ca61e9c856f137 100644
--- a/tests/perceiver/test_tokenization_perceiver.py
+++ b/tests/models/perceiver/test_tokenization_perceiver.py
@@ -24,7 +24,7 @@
from transformers import AddedToken, BatchEncoding, PerceiverTokenizer
from transformers.utils import cached_property, is_tf_available, is_torch_available
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
if is_torch_available():
diff --git a/tests/roberta/__init__.py b/tests/models/phobert/__init__.py
similarity index 100%
rename from tests/roberta/__init__.py
rename to tests/models/phobert/__init__.py
diff --git a/tests/phobert/test_tokenization_phobert.py b/tests/models/phobert/test_tokenization_phobert.py
similarity index 97%
rename from tests/phobert/test_tokenization_phobert.py
rename to tests/models/phobert/test_tokenization_phobert.py
index 87bdb95c5290b9..de16c154c92524 100644
--- a/tests/phobert/test_tokenization_phobert.py
+++ b/tests/models/phobert/test_tokenization_phobert.py
@@ -18,7 +18,7 @@
from transformers.models.phobert.tokenization_phobert import VOCAB_FILES_NAMES, PhobertTokenizer
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class PhobertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/roformer/__init__.py b/tests/models/plbart/__init__.py
similarity index 100%
rename from tests/roformer/__init__.py
rename to tests/models/plbart/__init__.py
diff --git a/tests/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py
similarity index 99%
rename from tests/plbart/test_modeling_plbart.py
rename to tests/models/plbart/test_modeling_plbart.py
index 6a307e244e2a59..171531503d2d33 100644
--- a/tests/plbart/test_modeling_plbart.py
+++ b/tests/models/plbart/test_modeling_plbart.py
@@ -23,9 +23,9 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -219,6 +219,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
)
all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
diff --git a/tests/plbart/test_tokenization_plbart.py b/tests/models/plbart/test_tokenization_plbart.py
similarity index 97%
rename from tests/plbart/test_tokenization_plbart.py
rename to tests/models/plbart/test_tokenization_plbart.py
index d83964e86d8500..9aed6040f3fda8 100644
--- a/tests/plbart/test_tokenization_plbart.py
+++ b/tests/models/plbart/test_tokenization_plbart.py
@@ -12,17 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, PLBartTokenizer, is_torch_available
-from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
+from transformers.testing_utils import (
+ get_tests_dir,
+ nested_simplify,
+ require_sentencepiece,
+ require_tokenizers,
+ require_torch,
+)
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available():
diff --git a/tests/segformer/__init__.py b/tests/models/poolformer/__init__.py
similarity index 100%
rename from tests/segformer/__init__.py
rename to tests/models/poolformer/__init__.py
diff --git a/tests/poolformer/test_feature_extraction_poolformer.py b/tests/models/poolformer/test_feature_extraction_poolformer.py
similarity index 98%
rename from tests/poolformer/test_feature_extraction_poolformer.py
rename to tests/models/poolformer/test_feature_extraction_poolformer.py
index 5fd830a7ffbd90..bb65835d5dc1d7 100644
--- a/tests/poolformer/test_feature_extraction_poolformer.py
+++ b/tests/models/poolformer/test_feature_extraction_poolformer.py
@@ -20,7 +20,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/poolformer/test_modeling_poolformer.py b/tests/models/poolformer/test_modeling_poolformer.py
similarity index 97%
rename from tests/poolformer/test_modeling_poolformer.py
rename to tests/models/poolformer/test_modeling_poolformer.py
index 1c6ea9b0a24a14..7dc47d2c77f9a4 100644
--- a/tests/poolformer/test_modeling_poolformer.py
+++ b/tests/models/poolformer/test_modeling_poolformer.py
@@ -22,8 +22,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -142,6 +142,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ @unittest.skip(reason="PoolFormer does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip("PoolFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/sew/__init__.py b/tests/models/prophetnet/__init__.py
similarity index 100%
rename from tests/sew/__init__.py
rename to tests/models/prophetnet/__init__.py
diff --git a/tests/prophetnet/test_modeling_prophetnet.py b/tests/models/prophetnet/test_modeling_prophetnet.py
similarity index 97%
rename from tests/prophetnet/test_modeling_prophetnet.py
rename to tests/models/prophetnet/test_modeling_prophetnet.py
index 17bf4523e75657..9ac8ea81e20a94 100644
--- a/tests/prophetnet/test_modeling_prophetnet.py
+++ b/tests/models/prophetnet/test_modeling_prophetnet.py
@@ -20,9 +20,9 @@
from transformers import ProphetNetConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -1226,7 +1226,15 @@ def test_cnndm_inference(self):
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased-cnndm")
- ARTICLE_TO_SUMMARIZE = "USTC was founded in Beijing by the Chinese Academy of Sciences (CAS) in September 1958. The Director of CAS, Mr. Guo Moruo was appointed the first president of USTC. USTC's founding mission was to develop a high-level science and technology workforce, as deemed critical for development of China's economy, defense, and science and technology education. The establishment was hailed as \"A Major Event in the History of Chinese Education and Science.\" CAS has supported USTC by combining most of its institutes with the departments of the university. USTC is listed in the top 16 national key universities, becoming the youngest national key university.".lower()
+ ARTICLE_TO_SUMMARIZE = (
+ "USTC was founded in Beijing by the Chinese Academy of Sciences (CAS) in September 1958. The Director of"
+ " CAS, Mr. Guo Moruo was appointed the first president of USTC. USTC's founding mission was to develop a"
+ " high-level science and technology workforce, as deemed critical for development of China's economy,"
+ ' defense, and science and technology education. The establishment was hailed as "A Major Event in the'
+ ' History of Chinese Education and Science." CAS has supported USTC by combining most of its institutes'
+ " with the departments of the university. USTC is listed in the top 16 national key universities, becoming"
+ " the youngest national key university.".lower()
+ )
input_ids = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=511, return_tensors="pt").input_ids
input_ids = input_ids.to(torch_device)
@@ -1234,7 +1242,10 @@ def test_cnndm_inference(self):
summary_ids = model.generate(
input_ids, num_beams=4, length_penalty=1.0, no_repeat_ngram_size=3, early_stopping=True
)
- EXPECTED_SUMMARIZE_512 = "us ##tc was founded by the chinese academy of sciences ( cas ) in 1958 . [X_SEP] us ##tc is listed in the top 16 national key universities ."
+ EXPECTED_SUMMARIZE_512 = (
+ "us ##tc was founded by the chinese academy of sciences ( cas ) in 1958 . [X_SEP] us ##tc is listed in the"
+ " top 16 national key universities ."
+ )
generated_titles = [
" ".join(tokenizer.convert_ids_to_tokens(g, skip_special_tokens=True)) for g in summary_ids
]
@@ -1251,7 +1262,8 @@ def test_cnndm_inference(self):
EXPECTED_SUMMARIZE_100 = (
r"us ##tc was founded in beijing by the chinese academy of sciences ( cas ) in 1958 . [X_SEP] us ##tc "
"'"
- ' s founding mission was to develop a high - level science and technology workforce . [X_SEP] establishment hailed as " a major event in the history of chinese education and science "'
+ " s founding mission was to develop a high - level science and technology workforce . [X_SEP]"
+ ' establishment hailed as " a major event in the history of chinese education and science "'
)
generated_titles = [
" ".join(tokenizer.convert_ids_to_tokens(g, skip_special_tokens=True)) for g in summary_ids
diff --git a/tests/prophetnet/test_tokenization_prophetnet.py b/tests/models/prophetnet/test_tokenization_prophetnet.py
similarity index 98%
rename from tests/prophetnet/test_tokenization_prophetnet.py
rename to tests/models/prophetnet/test_tokenization_prophetnet.py
index 270bbf53fdfc36..8d95eb310025d1 100644
--- a/tests/prophetnet/test_tokenization_prophetnet.py
+++ b/tests/models/prophetnet/test_tokenization_prophetnet.py
@@ -28,7 +28,7 @@
from transformers.models.prophetnet.tokenization_prophetnet import VOCAB_FILES_NAMES, ProphetNetTokenizer
from transformers.testing_utils import require_torch, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -141,7 +141,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
diff --git a/tests/sew_d/__init__.py b/tests/models/qdqbert/__init__.py
similarity index 100%
rename from tests/sew_d/__init__.py
rename to tests/models/qdqbert/__init__.py
diff --git a/tests/qdqbert/test_modeling_qdqbert.py b/tests/models/qdqbert/test_modeling_qdqbert.py
similarity index 99%
rename from tests/qdqbert/test_modeling_qdqbert.py
rename to tests/models/qdqbert/test_modeling_qdqbert.py
index 5e53e59126c789..82bf5e3e336457 100644
--- a/tests/qdqbert/test_modeling_qdqbert.py
+++ b/tests/models/qdqbert/test_modeling_qdqbert.py
@@ -21,8 +21,8 @@
from transformers import QDQBertConfig, is_torch_available
from transformers.testing_utils import require_pytorch_quantization, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/speech_encoder_decoder/__init__.py b/tests/models/rag/__init__.py
similarity index 100%
rename from tests/speech_encoder_decoder/__init__.py
rename to tests/models/rag/__init__.py
diff --git a/tests/rag/test_modeling_rag.py b/tests/models/rag/test_modeling_rag.py
similarity index 99%
rename from tests/rag/test_modeling_rag.py
rename to tests/models/rag/test_modeling_rag.py
index 6914318cfac201..80819663a107d1 100644
--- a/tests/rag/test_modeling_rag.py
+++ b/tests/models/rag/test_modeling_rag.py
@@ -20,7 +20,6 @@
import shutil
import tempfile
import unittest
-from os.path import dirname
from unittest.mock import patch
import numpy as np
@@ -30,6 +29,7 @@
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import (
+ get_tests_dir,
require_sentencepiece,
require_tokenizers,
require_torch,
@@ -46,7 +46,7 @@
TOLERANCE = 1e-3
-T5_SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+T5_SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available() and is_datasets_available() and is_faiss_available():
import torch
from datasets import Dataset
diff --git a/tests/rag/test_modeling_tf_rag.py b/tests/models/rag/test_modeling_tf_rag.py
similarity index 98%
rename from tests/rag/test_modeling_tf_rag.py
rename to tests/models/rag/test_modeling_tf_rag.py
index d9050acb6311e9..314ce099baf65f 100644
--- a/tests/rag/test_modeling_tf_rag.py
+++ b/tests/models/rag/test_modeling_tf_rag.py
@@ -838,13 +838,6 @@ def test_rag_token_generate_batch(self):
input_ids = input_dict.input_ids
attention_mask = input_dict.attention_mask
- output_ids = rag_token.generate(
- input_ids,
- attention_mask=attention_mask,
- )
-
- outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-
EXPECTED_OUTPUTS = [
" albert einstein",
" september 22, 2017",
@@ -855,7 +848,21 @@ def test_rag_token_generate_batch(self):
" 7.1. 2",
" 13",
]
- self.assertListEqual(outputs, EXPECTED_OUTPUTS)
+
+ # Split into 2 batches of 4 examples to avoid GPU OOM.
+ output_ids = rag_token.generate(
+ input_ids[:4],
+ attention_mask=attention_mask[:4],
+ )
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertListEqual(outputs, EXPECTED_OUTPUTS[:4])
+
+ output_ids = rag_token.generate(
+ input_ids[4:],
+ attention_mask=attention_mask[4:],
+ )
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+ self.assertListEqual(outputs, EXPECTED_OUTPUTS[4:])
@slow
def test_rag_sequence_generate_batch(self):
diff --git a/tests/rag/test_retrieval_rag.py b/tests/models/rag/test_retrieval_rag.py
similarity index 100%
rename from tests/rag/test_retrieval_rag.py
rename to tests/models/rag/test_retrieval_rag.py
diff --git a/tests/rag/test_tokenization_rag.py b/tests/models/rag/test_tokenization_rag.py
similarity index 100%
rename from tests/rag/test_tokenization_rag.py
rename to tests/models/rag/test_tokenization_rag.py
diff --git a/tests/speech_to_text/__init__.py b/tests/models/realm/__init__.py
similarity index 100%
rename from tests/speech_to_text/__init__.py
rename to tests/models/realm/__init__.py
diff --git a/tests/realm/test_modeling_realm.py b/tests/models/realm/test_modeling_realm.py
similarity index 99%
rename from tests/realm/test_modeling_realm.py
rename to tests/models/realm/test_modeling_realm.py
index 02eaa6556e9f9f..e084cf5a4e18a3 100644
--- a/tests/realm/test_modeling_realm.py
+++ b/tests/models/realm/test_modeling_realm.py
@@ -22,8 +22,8 @@
from transformers import RealmConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/realm/test_retrieval_realm.py b/tests/models/realm/test_retrieval_realm.py
similarity index 100%
rename from tests/realm/test_retrieval_realm.py
rename to tests/models/realm/test_retrieval_realm.py
diff --git a/tests/realm/test_tokenization_realm.py b/tests/models/realm/test_tokenization_realm.py
similarity index 98%
rename from tests/realm/test_tokenization_realm.py
rename to tests/models/realm/test_tokenization_realm.py
index 6bc31eaa577693..2a065ceee66af6 100644
--- a/tests/realm/test_tokenization_realm.py
+++ b/tests/models/realm/test_tokenization_realm.py
@@ -28,7 +28,7 @@
from transformers.models.realm.tokenization_realm import RealmTokenizer
from transformers.testing_utils import require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin, filter_non_english
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
@require_tokenizers
@@ -186,7 +186,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
diff --git a/tests/speech_to_text_2/__init__.py b/tests/models/reformer/__init__.py
similarity index 100%
rename from tests/speech_to_text_2/__init__.py
rename to tests/models/reformer/__init__.py
diff --git a/tests/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py
similarity index 99%
rename from tests/reformer/test_modeling_reformer.py
rename to tests/models/reformer/test_modeling_reformer.py
index d0259bacae18ca..0e5a801e7efb7c 100644
--- a/tests/reformer/test_modeling_reformer.py
+++ b/tests/models/reformer/test_modeling_reformer.py
@@ -25,9 +25,9 @@
torch_device,
)
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -574,7 +574,10 @@ def test_reformer_model_fp16_generate(self):
@require_torch_multi_gpu
@unittest.skip(
- reason="Reformer does not work with data parallel (DP) because of a bug in PyTorch: https://github.com/pytorch/pytorch/issues/36035"
+ reason=(
+ "Reformer does not work with data parallel (DP) because of a bug in PyTorch:"
+ " https://github.com/pytorch/pytorch/issues/36035"
+ )
)
def test_multi_gpu_data_parallel_forward(self):
pass
diff --git a/tests/reformer/test_tokenization_reformer.py b/tests/models/reformer/test_tokenization_reformer.py
similarity index 95%
rename from tests/reformer/test_tokenization_reformer.py
rename to tests/models/reformer/test_tokenization_reformer.py
index 22e6e455e6f969..37ea66847f2d0c 100644
--- a/tests/reformer/test_tokenization_reformer.py
+++ b/tests/models/reformer/test_tokenization_reformer.py
@@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from transformers import SPIECE_UNDERLINE, ReformerTokenizer, ReformerTokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, require_torch, slow
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_sentencepiece
@@ -216,7 +214,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to , such as saoneuhaoesuth"
+ )
original_tokenizer_encodings = [
108,
265,
diff --git a/tests/splinter/__init__.py b/tests/models/regnet/__init__.py
similarity index 100%
rename from tests/splinter/__init__.py
rename to tests/models/regnet/__init__.py
diff --git a/tests/regnet/test_modeling_regnet.py b/tests/models/regnet/test_modeling_regnet.py
similarity index 97%
rename from tests/regnet/test_modeling_regnet.py
rename to tests/models/regnet/test_modeling_regnet.py
index 2660108e96ae0e..4879bf259efc2c 100644
--- a/tests/regnet/test_modeling_regnet.py
+++ b/tests/models/regnet/test_modeling_regnet.py
@@ -22,8 +22,8 @@
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -147,6 +147,10 @@ def test_config(self):
def create_and_test_config_common_properties(self):
return
+ @unittest.skip(reason="RegNet does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip(reason="RegNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/squeezebert/__init__.py b/tests/models/rembert/__init__.py
similarity index 100%
rename from tests/squeezebert/__init__.py
rename to tests/models/rembert/__init__.py
diff --git a/tests/rembert/test_modeling_rembert.py b/tests/models/rembert/test_modeling_rembert.py
similarity index 99%
rename from tests/rembert/test_modeling_rembert.py
rename to tests/models/rembert/test_modeling_rembert.py
index 94ec90497f4b8f..a3ffd6dfd5a164 100644
--- a/tests/rembert/test_modeling_rembert.py
+++ b/tests/models/rembert/test_modeling_rembert.py
@@ -20,8 +20,8 @@
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/rembert/test_modeling_tf_rembert.py b/tests/models/rembert/test_modeling_tf_rembert.py
similarity index 99%
rename from tests/rembert/test_modeling_tf_rembert.py
rename to tests/models/rembert/test_modeling_tf_rembert.py
index d5d52062e8c92f..6d4cf0a523b933 100644
--- a/tests/rembert/test_modeling_tf_rembert.py
+++ b/tests/models/rembert/test_modeling_tf_rembert.py
@@ -19,8 +19,8 @@
from transformers import RemBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/swin/__init__.py b/tests/models/resnet/__init__.py
similarity index 100%
rename from tests/swin/__init__.py
rename to tests/models/resnet/__init__.py
diff --git a/tests/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py
similarity index 97%
rename from tests/resnet/test_modeling_resnet.py
rename to tests/models/resnet/test_modeling_resnet.py
index 7a0d1ee473d7ea..83f08b68afb8be 100644
--- a/tests/resnet/test_modeling_resnet.py
+++ b/tests/models/resnet/test_modeling_resnet.py
@@ -22,8 +22,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -147,6 +147,10 @@ def test_config(self):
def create_and_test_config_common_properties(self):
return
+ @unittest.skip(reason="ResNet does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip(reason="ResNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/t5/__init__.py b/tests/models/retribert/__init__.py
similarity index 100%
rename from tests/t5/__init__.py
rename to tests/models/retribert/__init__.py
diff --git a/tests/models/retribert/test_tokenization_retribert.py b/tests/models/retribert/test_tokenization_retribert.py
new file mode 100644
index 00000000000000..e2bf4e61b1ac09
--- /dev/null
+++ b/tests/models/retribert/test_tokenization_retribert.py
@@ -0,0 +1,384 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the RetriBERT tokenizer. """
+
+
+import os
+import unittest
+
+from transformers import RetriBertTokenizer, RetriBertTokenizerFast
+from transformers.models.bert.tokenization_bert import (
+ VOCAB_FILES_NAMES,
+ BasicTokenizer,
+ WordpieceTokenizer,
+ _is_control,
+ _is_punctuation,
+ _is_whitespace,
+)
+from transformers.testing_utils import require_tokenizers, require_torch, slow
+
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
+
+
+# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
+@require_tokenizers
+class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ tokenizer_class = RetriBertTokenizer
+ test_slow_tokenizer = True
+ rust_tokenizer_class = RetriBertTokenizerFast
+ test_rust_tokenizer = True
+ space_between_special_tokens = True
+ from_pretrained_filter = filter_non_english
+
+ def setUp(self):
+ super().setUp()
+
+ vocab_tokens = [
+ "[UNK]",
+ "[CLS]",
+ "[SEP]",
+ "[PAD]",
+ "[MASK]",
+ "want",
+ "##want",
+ "##ed",
+ "wa",
+ "un",
+ "runn",
+ "##ing",
+ ",",
+ "low",
+ "lowest",
+ ]
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+ with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+ def get_input_output_texts(self, tokenizer):
+ input_text = "UNwant\u00E9d,running"
+ output_text = "unwanted, running"
+ return input_text, output_text
+
+ def test_full_tokenizer(self):
+ tokenizer = self.tokenizer_class(self.vocab_file)
+
+ tokens = tokenizer.tokenize("UNwant\u00E9d,running")
+ self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+ self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
+
+ def test_rust_and_python_full_tokenizers(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ sequence = "UNwant\u00E9d,running"
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ # With lower casing
+ tokenizer = self.get_tokenizer(do_lower_case=True)
+ rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)
+
+ sequence = "UNwant\u00E9d,running"
+
+ tokens = tokenizer.tokenize(sequence)
+ rust_tokens = rust_tokenizer.tokenize(sequence)
+ self.assertListEqual(tokens, rust_tokens)
+
+ ids = tokenizer.encode(sequence, add_special_tokens=False)
+ rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+ self.assertListEqual(ids, rust_ids)
+
+ rust_tokenizer = self.get_rust_tokenizer()
+ ids = tokenizer.encode(sequence)
+ rust_ids = rust_tokenizer.encode(sequence)
+ self.assertListEqual(ids, rust_ids)
+
+ def test_chinese(self):
+ tokenizer = BasicTokenizer()
+
+ self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
+
+ def test_basic_tokenizer_lower(self):
+ tokenizer = BasicTokenizer(do_lower_case=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_lower_strip_accents_false(self):
+ tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hƤllo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
+
+ def test_basic_tokenizer_lower_strip_accents_true(self):
+ tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_lower_strip_accents_default(self):
+ tokenizer = BasicTokenizer(do_lower_case=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+ )
+ self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+ def test_basic_tokenizer_no_lower(self):
+ tokenizer = BasicTokenizer(do_lower_case=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_no_lower_strip_accents_false(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["HƤLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_no_lower_strip_accents_true(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHƤLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
+ )
+
+ def test_basic_tokenizer_respects_never_split_tokens(self):
+ tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
+
+ self.assertListEqual(
+ tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
+ )
+
+ def test_wordpiece_tokenizer(self):
+ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
+
+ vocab = {}
+ for i, token in enumerate(vocab_tokens):
+ vocab[token] = i
+ tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
+
+ self.assertListEqual(tokenizer.tokenize(""), [])
+
+ self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"])
+
+ self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
+
+ def test_is_whitespace(self):
+ self.assertTrue(_is_whitespace(" "))
+ self.assertTrue(_is_whitespace("\t"))
+ self.assertTrue(_is_whitespace("\r"))
+ self.assertTrue(_is_whitespace("\n"))
+ self.assertTrue(_is_whitespace("\u00A0"))
+
+ self.assertFalse(_is_whitespace("A"))
+ self.assertFalse(_is_whitespace("-"))
+
+ def test_is_control(self):
+ self.assertTrue(_is_control("\u0005"))
+
+ self.assertFalse(_is_control("A"))
+ self.assertFalse(_is_control(" "))
+ self.assertFalse(_is_control("\t"))
+ self.assertFalse(_is_control("\r"))
+
+ def test_is_punctuation(self):
+ self.assertTrue(_is_punctuation("-"))
+ self.assertTrue(_is_punctuation("$"))
+ self.assertTrue(_is_punctuation("`"))
+ self.assertTrue(_is_punctuation("."))
+
+ self.assertFalse(_is_punctuation("A"))
+ self.assertFalse(_is_punctuation(" "))
+
+ def test_clean_text(self):
+ tokenizer = self.get_tokenizer()
+ rust_tokenizer = self.get_rust_tokenizer()
+
+ # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340
+ self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]])
+
+ self.assertListEqual(
+ [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]
+ )
+
+ @slow
+ def test_sequence_builders(self):
+ tokenizer = self.tokenizer_class.from_pretrained("yjernite/retribert-base-uncased")
+
+ text = tokenizer.encode("sequence builders", add_special_tokens=False)
+ text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
+
+ encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
+ encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
+
+ assert encoded_sentence == [101] + text + [102]
+ assert encoded_pair == [101] + text + [102] + text_2 + [102]
+
+ def test_offsets_with_special_characters(self):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ sentence = f"A, naĆÆve {tokenizer_r.mask_token} AllenNLP sentence."
+ tokens = tokenizer_r.encode_plus(
+ sentence,
+ return_attention_mask=False,
+ return_token_type_ids=False,
+ return_offsets_mapping=True,
+ add_special_tokens=True,
+ )
+
+ do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False
+ expected_results = (
+ [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "A"),
+ ((1, 2), ","),
+ ((3, 5), "na"),
+ ((5, 6), "##ĆÆ"),
+ ((6, 8), "##ve"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "Allen"),
+ ((21, 23), "##NL"),
+ ((23, 24), "##P"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ if not do_lower_case
+ else [
+ ((0, 0), tokenizer_r.cls_token),
+ ((0, 1), "a"),
+ ((1, 2), ","),
+ ((3, 8), "naive"),
+ ((9, 15), tokenizer_r.mask_token),
+ ((16, 21), "allen"),
+ ((21, 23), "##nl"),
+ ((23, 24), "##p"),
+ ((25, 33), "sentence"),
+ ((33, 34), "."),
+ ((0, 0), tokenizer_r.sep_token),
+ ]
+ )
+
+ self.assertEqual(
+ [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
+ )
+ self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+
+ def test_change_tokenize_chinese_chars(self):
+ list_of_commun_chinese_char = ["ē", "äŗŗ", "ę"]
+ text_with_chinese_char = "".join(list_of_commun_chinese_char)
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+ kwargs["tokenize_chinese_chars"] = True
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+ ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+
+ tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+ tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+ # it is expected that each Chinese character is not preceded by "##"
+ self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char)
+ self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char)
+
+ kwargs["tokenize_chinese_chars"] = False
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+ ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+ ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+
+ tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+ tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+ # it is expected that only the first Chinese character is not preceded by "##".
+ expected_tokens = [
+ f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
+ ]
+ self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
+ self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
+
+ # RetriBertModel doesn't define `get_input_embeddings` and it's forward method doesn't take only the output of the tokenizer as input
+ @require_torch
+ @slow
+ def test_torch_encode_plus_sent_to_model(self):
+ import torch
+
+ from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+ MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+ tokenizers = self.get_tokenizers(do_lower_case=False)
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+
+ if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+ return
+
+ config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+ config = config_class()
+
+ if config.is_encoder_decoder or config.pad_token_id is None:
+ return
+
+ model = model_class(config)
+
+ # The following test is different from the common's one
+ self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
+
+ # Build sequence
+ first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+ sequence = " ".join(first_ten_tokens)
+ encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+
+ # Ensure that the BatchEncoding.to() method works.
+ encoded_sequence.to(model.device)
+
+ batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
+ # This should not fail
+
+ with torch.no_grad(): # saves some time
+ # The following lines are different from the common's ones
+ model.embed_questions(**encoded_sequence)
+ model.embed_questions(**batch_encoded_sequence)
diff --git a/tests/tapas/__init__.py b/tests/models/roberta/__init__.py
similarity index 100%
rename from tests/tapas/__init__.py
rename to tests/models/roberta/__init__.py
diff --git a/tests/roberta/test_modeling_flax_roberta.py b/tests/models/roberta/test_modeling_flax_roberta.py
similarity index 86%
rename from tests/roberta/test_modeling_flax_roberta.py
rename to tests/models/roberta/test_modeling_flax_roberta.py
index db92868769f718..5bd8a56022ce8c 100644
--- a/tests/roberta/test_modeling_flax_roberta.py
+++ b/tests/models/roberta/test_modeling_flax_roberta.py
@@ -19,11 +19,12 @@
from transformers import RobertaConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
from transformers.models.roberta.modeling_flax_roberta import (
+ FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
@@ -112,6 +113,22 @@ def prepare_config_and_inputs_for_common(self):
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict
+ def prepare_config_and_inputs_for_decoder(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, token_type_ids, attention_mask = config_and_inputs
+
+ config.is_decoder = True
+ encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+ encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+ return (
+ config,
+ input_ids,
+ token_type_ids,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+
@require_flax
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
@@ -121,6 +138,7 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxRobertaModel,
+ FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
diff --git a/tests/roberta/test_modeling_roberta.py b/tests/models/roberta/test_modeling_roberta.py
similarity index 98%
rename from tests/roberta/test_modeling_roberta.py
rename to tests/models/roberta/test_modeling_roberta.py
index ab92c9dfbd6585..7163a357021e5e 100644
--- a/tests/roberta/test_modeling_roberta.py
+++ b/tests/models/roberta/test_modeling_roberta.py
@@ -20,9 +20,9 @@
from transformers import RobertaConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -112,6 +112,11 @@ def get_config(self):
initializer_range=self.initializer_range,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/roberta/test_modeling_tf_roberta.py b/tests/models/roberta/test_modeling_tf_roberta.py
similarity index 99%
rename from tests/roberta/test_modeling_tf_roberta.py
rename to tests/models/roberta/test_modeling_tf_roberta.py
index 9771673d8748ec..f9408b84171d3c 100644
--- a/tests/roberta/test_modeling_tf_roberta.py
+++ b/tests/models/roberta/test_modeling_tf_roberta.py
@@ -19,8 +19,8 @@
from transformers import RobertaConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/roberta/test_tokenization_roberta.py b/tests/models/roberta/test_tokenization_roberta.py
similarity index 99%
rename from tests/roberta/test_tokenization_roberta.py
rename to tests/models/roberta/test_tokenization_roberta.py
index a898d9bf5f5cf1..46ce5983f08100 100644
--- a/tests/roberta/test_tokenization_roberta.py
+++ b/tests/models/roberta/test_tokenization_roberta.py
@@ -23,7 +23,7 @@
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
diff --git a/tests/tapex/__init__.py b/tests/models/roformer/__init__.py
similarity index 100%
rename from tests/tapex/__init__.py
rename to tests/models/roformer/__init__.py
diff --git a/tests/roformer/test_modeling_flax_roformer.py b/tests/models/roformer/test_modeling_flax_roformer.py
similarity index 98%
rename from tests/roformer/test_modeling_flax_roformer.py
rename to tests/models/roformer/test_modeling_flax_roformer.py
index 01b643e897fdce..d45c08efdbb394 100644
--- a/tests/roformer/test_modeling_flax_roformer.py
+++ b/tests/models/roformer/test_modeling_flax_roformer.py
@@ -19,7 +19,7 @@
from transformers import RoFormerConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
diff --git a/tests/roformer/test_modeling_roformer.py b/tests/models/roformer/test_modeling_roformer.py
similarity index 99%
rename from tests/roformer/test_modeling_roformer.py
rename to tests/models/roformer/test_modeling_roformer.py
index f5177a91d5a90b..b1d7f3d8a67c3f 100644
--- a/tests/roformer/test_modeling_roformer.py
+++ b/tests/models/roformer/test_modeling_roformer.py
@@ -20,8 +20,8 @@
from transformers import RoFormerConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/roformer/test_modeling_tf_roformer.py b/tests/models/roformer/test_modeling_tf_roformer.py
similarity index 99%
rename from tests/roformer/test_modeling_tf_roformer.py
rename to tests/models/roformer/test_modeling_tf_roformer.py
index 9a23ca3b83d222..d32d30ae8ad92c 100644
--- a/tests/roformer/test_modeling_tf_roformer.py
+++ b/tests/models/roformer/test_modeling_tf_roformer.py
@@ -19,8 +19,8 @@
from transformers import RoFormerConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/roformer/test_tokenization_roformer.py b/tests/models/roformer/test_tokenization_roformer.py
similarity index 94%
rename from tests/roformer/test_tokenization_roformer.py
rename to tests/models/roformer/test_tokenization_roformer.py
index e5db42890d992e..7546bc2e41ddd9 100644
--- a/tests/roformer/test_tokenization_roformer.py
+++ b/tests/models/roformer/test_tokenization_roformer.py
@@ -18,7 +18,7 @@
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
from transformers.testing_utils import require_rjieba, require_tokenizers
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_rjieba
@@ -71,3 +71,7 @@ def test_training_new_tokenizer(self):
# can't train new_tokenizer via Tokenizers lib
def test_training_new_tokenizer_with_special_tokens_change(self):
pass
+
+ # can't serialise custom PreTokenizer
+ def test_save_slow_from_fast_and_reload_fast(self):
+ pass
diff --git a/tests/transfo_xl/__init__.py b/tests/models/segformer/__init__.py
similarity index 100%
rename from tests/transfo_xl/__init__.py
rename to tests/models/segformer/__init__.py
diff --git a/tests/segformer/test_feature_extraction_segformer.py b/tests/models/segformer/test_feature_extraction_segformer.py
similarity index 99%
rename from tests/segformer/test_feature_extraction_segformer.py
rename to tests/models/segformer/test_feature_extraction_segformer.py
index c34cb6fc0df2b5..75083012d87595 100644
--- a/tests/segformer/test_feature_extraction_segformer.py
+++ b/tests/models/segformer/test_feature_extraction_segformer.py
@@ -22,7 +22,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/segformer/test_modeling_segformer.py b/tests/models/segformer/test_modeling_segformer.py
similarity index 99%
rename from tests/segformer/test_modeling_segformer.py
rename to tests/models/segformer/test_modeling_segformer.py
index 668298507871e3..9af59299f8ecea 100644
--- a/tests/segformer/test_modeling_segformer.py
+++ b/tests/models/segformer/test_modeling_segformer.py
@@ -22,8 +22,8 @@
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
diff --git a/tests/trocr/__init__.py b/tests/models/sew/__init__.py
similarity index 100%
rename from tests/trocr/__init__.py
rename to tests/models/sew/__init__.py
diff --git a/tests/sew/test_modeling_sew.py b/tests/models/sew/test_modeling_sew.py
similarity index 99%
rename from tests/sew/test_modeling_sew.py
rename to tests/models/sew/test_modeling_sew.py
index e8b06610dfa79a..9df69f84677e54 100644
--- a/tests/sew/test_modeling_sew.py
+++ b/tests/models/sew/test_modeling_sew.py
@@ -23,8 +23,8 @@
from transformers import SEWConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -108,7 +108,7 @@ def __init__(
self.encoder_seq_length = self.output_seq_length // self.squeeze_factor
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
diff --git a/tests/unispeech/__init__.py b/tests/models/sew_d/__init__.py
similarity index 100%
rename from tests/unispeech/__init__.py
rename to tests/models/sew_d/__init__.py
diff --git a/tests/sew_d/test_modeling_sew_d.py b/tests/models/sew_d/test_modeling_sew_d.py
similarity index 99%
rename from tests/sew_d/test_modeling_sew_d.py
rename to tests/models/sew_d/test_modeling_sew_d.py
index 796bd8805e654c..334b10abf3ec8f 100644
--- a/tests/sew_d/test_modeling_sew_d.py
+++ b/tests/models/sew_d/test_modeling_sew_d.py
@@ -23,8 +23,8 @@
from transformers import SEWDConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -122,7 +122,7 @@ def __init__(
self.encoder_seq_length = self.output_seq_length // self.squeeze_factor
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
diff --git a/tests/unispeech_sat/__init__.py b/tests/models/speech_encoder_decoder/__init__.py
similarity index 100%
rename from tests/unispeech_sat/__init__.py
rename to tests/models/speech_encoder_decoder/__init__.py
diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
similarity index 86%
rename from tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
rename to tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
index 0549e650645238..432d16d3facd1a 100644
--- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
+++ b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
@@ -21,9 +21,10 @@
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
+from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
+from ..bert.test_modeling_flax_bert import FlaxBertModelTester
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
-from ..test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
@@ -34,6 +35,7 @@
from flax.traverse_util import flatten_dict
from transformers import (
FlaxBartForCausalLM,
+ FlaxBertForCausalLM,
FlaxGPT2LMHeadModel,
FlaxSpeechEncoderDecoderModel,
FlaxWav2Vec2Model,
@@ -582,7 +584,7 @@ def get_pretrained_model_and_inputs(self):
"facebook/wav2vec2-large-lv60", "gpt2-medium"
)
batch_size = 13
- input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size)
+ input_values = floats_tensor([batch_size, 512], scale=1.0)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
@@ -638,7 +640,7 @@ def test_flaxwav2vec2gpt2_pt_flax_equivalence(self):
# prepare inputs
batch_size = 13
- input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
+ input_values = floats_tensor([batch_size, 512], scale=1.0)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
@@ -699,7 +701,7 @@ def get_pretrained_model_and_inputs(self):
"facebook/wav2vec2-large-lv60", "bart-large"
)
batch_size = 13
- input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size)
+ input_values = floats_tensor([batch_size, 512], scale=1.0)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
@@ -753,6 +755,121 @@ def test_flaxwav2vec2bart_pt_flax_equivalence(self):
pt_model.to(torch_device)
pt_model.eval()
+ # prepare inputs
+ batch_size = 13
+ input_values = floats_tensor([batch_size, 512], scale=1.0)
+ attention_mask = random_attention_mask([batch_size, 512])
+ decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
+ decoder_attention_mask = random_attention_mask([batch_size, 4])
+ inputs_dict = {
+ "inputs": input_values,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ }
+
+ flax_inputs = inputs_dict
+ pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
+
+ with torch.no_grad():
+ pt_outputs = pt_model(**pt_inputs)
+ pt_logits = pt_outputs.logits
+ pt_outputs = pt_outputs.to_tuple()
+
+ fx_outputs = fx_model(**inputs_dict)
+ fx_logits = fx_outputs.logits
+ fx_outputs = fx_outputs.to_tuple()
+
+ self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
+ self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
+
+ # PT -> Flax
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pt_model.save_pretrained(tmpdirname)
+ fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
+
+ fx_outputs_loaded = fx_model_loaded(**inputs_dict)
+ fx_logits_loaded = fx_outputs_loaded.logits
+ fx_outputs_loaded = fx_outputs_loaded.to_tuple()
+ self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
+ self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
+
+ # Flax -> PT
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ fx_model.save_pretrained(tmpdirname)
+ pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
+
+ pt_model_loaded.to(torch_device)
+ pt_model_loaded.eval()
+
+ with torch.no_grad():
+ pt_outputs_loaded = pt_model_loaded(**pt_inputs)
+ pt_logits_loaded = pt_outputs_loaded.logits
+ pt_outputs_loaded = pt_outputs_loaded.to_tuple()
+
+ self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
+ self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
+
+
+@require_flax
+class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
+ def get_pretrained_model_and_inputs(self):
+ model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+ "facebook/wav2vec2-large-lv60", "bert-large-uncased"
+ )
+ batch_size = 13
+ input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size)
+ attention_mask = random_attention_mask([batch_size, 512])
+ decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size)
+ decoder_attention_mask = random_attention_mask([batch_size, 4])
+ inputs = {
+ "inputs": input_values,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ }
+
+ return model, inputs
+
+ def get_encoder_decoder_model(self, config, decoder_config):
+ encoder_model = FlaxWav2Vec2Model(config)
+ decoder_model = FlaxBertForCausalLM(decoder_config)
+ return encoder_model, decoder_model
+
+ def prepare_config_and_inputs(self):
+ model_tester_encoder = FlaxWav2Vec2ModelTester(self, batch_size=13)
+ model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
+ encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
+ decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
+ (config, inputs, attention_mask) = encoder_config_and_inputs
+ (
+ decoder_config,
+ decoder_input_ids,
+ decoder_attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ) = decoder_config_and_inputs
+
+ # make sure that cross attention layers are added
+ decoder_config.add_cross_attention = True
+ return {
+ "config": config,
+ "inputs": inputs,
+ "attention_mask": attention_mask,
+ "decoder_config": decoder_config,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ "encoder_hidden_states": encoder_hidden_states,
+ }
+
+ @slow
+ def test_flaxwav2vec2bert_pt_flax_equivalence(self):
+ pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large")
+ fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True)
+
+ pt_model.to(torch_device)
+ pt_model.eval()
+
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
diff --git a/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
similarity index 98%
rename from tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
rename to tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
index c17792084d2f69..2d934744f9e424 100644
--- a/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
+++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
@@ -20,10 +20,10 @@
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
+from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..bert.test_modeling_bert import BertModelTester
from ..speech_to_text.test_modeling_speech_to_text import Speech2TextModelTester
from ..speech_to_text_2.test_modeling_speech_to_text_2 import Speech2Text2StandaloneDecoderModelTester
-from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..wav2vec2.test_modeling_wav2vec2 import Wav2Vec2ModelTester
@@ -425,7 +425,7 @@ def get_pretrained_model_and_inputs(self):
"facebook/wav2vec2-base-960h", "bert-base-cased"
)
batch_size = 13
- input_values = floats_tensor([batch_size, 512], model.encoder.config.vocab_size)
+ input_values = floats_tensor([batch_size, 512], scale=1.0)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
@@ -489,7 +489,7 @@ def get_pretrained_model_and_inputs(self):
"facebook/s2t-small-librispeech-asr", "bert-base-cased"
)
batch_size = 13
- input_features = floats_tensor([batch_size, 7, 80], model.encoder.config.vocab_size)
+ input_features = floats_tensor([batch_size, 7, 80], scale=1.0)
attention_mask = random_attention_mask([batch_size, 7])
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
diff --git a/tests/van/__init__.py b/tests/models/speech_to_text/__init__.py
similarity index 100%
rename from tests/van/__init__.py
rename to tests/models/speech_to_text/__init__.py
diff --git a/tests/speech_to_text/test_feature_extraction_speech_to_text.py b/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
similarity index 99%
rename from tests/speech_to_text/test_feature_extraction_speech_to_text.py
rename to tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
index 9d719e4e1bf91e..244b748c7139ff 100644
--- a/tests/speech_to_text/test_feature_extraction_speech_to_text.py
+++ b/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
@@ -23,7 +23,7 @@
from transformers import is_speech_available
from transformers.testing_utils import require_torch, require_torchaudio
-from ..test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
+from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_speech_available():
diff --git a/tests/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py
similarity index 86%
rename from tests/speech_to_text/test_modeling_speech_to_text.py
rename to tests/models/speech_to_text/test_modeling_speech_to_text.py
index 82b1c74c59dcdb..a1a625a9b4033f 100644
--- a/tests/speech_to_text/test_modeling_speech_to_text.py
+++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py
@@ -17,6 +17,7 @@
import copy
import inspect
import os
+import pickle
import tempfile
import unittest
@@ -30,11 +31,11 @@
slow,
torch_device,
)
-from transformers.utils import cached_property
+from transformers.utils import cached_property, is_torch_fx_available
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
@@ -43,6 +44,9 @@
from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
+
def prepare_speech_to_text_inputs_dict(
config,
@@ -271,6 +275,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
+ fx_compatible = True
test_pruning = False
test_missing_keys = False
@@ -715,6 +720,105 @@ def _create_and_check_torchscript(self, config, inputs_dict):
self.assertTrue(models_equal)
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = [
+ "input_ids",
+ "attention_mask",
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ "input_features",
+ ]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except RuntimeError as e:
+ self.fail(f"Couldn't trace module: {e}")
+
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
@require_torch
@require_torchaudio
@@ -770,8 +874,10 @@ def test_generation_librispeech_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"nor is mister cultar's manner less interesting than his matter",
- "he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind",
- "he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it but little of rocky ithaca",
+ "he tells us that at this festive season of the year with christmas and roast beef looming before us"
+ " similes drawn from eating and its results occur most readily to the mind",
+ "he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it"
+ " but little of rocky ithaca",
]
self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/speech_to_text/test_modeling_tf_speech_to_text.py b/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py
similarity index 98%
rename from tests/speech_to_text/test_modeling_tf_speech_to_text.py
rename to tests/models/speech_to_text/test_modeling_tf_speech_to_text.py
index 897c54722a2f80..613af6be0cd026 100644
--- a/tests/speech_to_text/test_modeling_tf_speech_to_text.py
+++ b/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py
@@ -21,8 +21,8 @@
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property, is_tf_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
@@ -602,7 +602,9 @@ def test_generation_librispeech_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"nor is mister cultar's manner less interesting than his matter",
- "he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind",
- "he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it but little of rocky ithaca",
+ "he tells us that at this festive season of the year with christmas and roast beef looming before us"
+ " similes drawn from eating and its results occur most readily to the mind",
+ "he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it"
+ " but little of rocky ithaca",
]
self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/speech_to_text/test_processor_speech_to_text.py b/tests/models/speech_to_text/test_processor_speech_to_text.py
similarity index 96%
rename from tests/speech_to_text/test_processor_speech_to_text.py
rename to tests/models/speech_to_text/test_processor_speech_to_text.py
index 05871a2bb06838..e6e43f1bb8d7ed 100644
--- a/tests/speech_to_text/test_processor_speech_to_text.py
+++ b/tests/models/speech_to_text/test_processor_speech_to_text.py
@@ -12,17 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import shutil
import tempfile
import unittest
-from os.path import dirname
from pathlib import Path
from shutil import copyfile
from transformers import Speech2TextTokenizer, is_speech_available
from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
-from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, require_torchaudio
from transformers.utils import FEATURE_EXTRACTOR_NAME
from .test_feature_extraction_speech_to_text import floats_list
@@ -32,7 +30,7 @@
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor
-SAMPLE_SP = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
@require_torch
diff --git a/tests/speech_to_text/test_tokenization_speech_to_text.py b/tests/models/speech_to_text/test_tokenization_speech_to_text.py
similarity index 97%
rename from tests/speech_to_text/test_tokenization_speech_to_text.py
rename to tests/models/speech_to_text/test_tokenization_speech_to_text.py
index 43f26092ff61eb..3b2ef9f456f401 100644
--- a/tests/speech_to_text/test_tokenization_speech_to_text.py
+++ b/tests/models/speech_to_text/test_tokenization_speech_to_text.py
@@ -12,21 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from pathlib import Path
from shutil import copyfile
from transformers import SPIECE_UNDERLINE, is_sentencepiece_available
from transformers.models.speech_to_text import Speech2TextTokenizer
from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_SP = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
if is_sentencepiece_available():
import sentencepiece as sp
diff --git a/tests/vilt/__init__.py b/tests/models/speech_to_text_2/__init__.py
similarity index 100%
rename from tests/vilt/__init__.py
rename to tests/models/speech_to_text_2/__init__.py
diff --git a/tests/speech_to_text_2/test_modeling_speech_to_text_2.py b/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py
similarity index 97%
rename from tests/speech_to_text_2/test_modeling_speech_to_text_2.py
rename to tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py
index 861e4acedc6ecf..d9717b406049ab 100644
--- a/tests/speech_to_text_2/test_modeling_speech_to_text_2.py
+++ b/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py
@@ -19,9 +19,9 @@
from transformers import Speech2Text2Config
from transformers.testing_utils import is_torch_available, require_torch, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -179,6 +179,7 @@ def prepare_config_and_inputs_for_common(self):
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
def setUp(
diff --git a/tests/speech_to_text_2/test_tokenization_speech_to_text_2.py b/tests/models/speech_to_text_2/test_tokenization_speech_to_text_2.py
similarity index 98%
rename from tests/speech_to_text_2/test_tokenization_speech_to_text_2.py
rename to tests/models/speech_to_text_2/test_tokenization_speech_to_text_2.py
index 072473851fc349..1000cce2898036 100644
--- a/tests/speech_to_text_2/test_tokenization_speech_to_text_2.py
+++ b/tests/models/speech_to_text_2/test_tokenization_speech_to_text_2.py
@@ -21,7 +21,7 @@
from transformers.models.speech_to_text_2 import Speech2Text2Tokenizer
from transformers.models.speech_to_text_2.tokenization_speech_to_text_2 import VOCAB_FILES_NAMES
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/vision_encoder_decoder/__init__.py b/tests/models/splinter/__init__.py
similarity index 100%
rename from tests/vision_encoder_decoder/__init__.py
rename to tests/models/splinter/__init__.py
diff --git a/tests/models/splinter/test_modeling_splinter.py b/tests/models/splinter/test_modeling_splinter.py
new file mode 100644
index 00000000000000..bc355bd2cd0719
--- /dev/null
+++ b/tests/models/splinter/test_modeling_splinter.py
@@ -0,0 +1,457 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Splinter model. """
+
+import copy
+import unittest
+
+from transformers import is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import SplinterConfig, SplinterForPreTraining, SplinterForQuestionAnswering, SplinterModel
+ from transformers.models.splinter.modeling_splinter import SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+class SplinterModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ num_questions=3,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ question_token_id=1,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_questions = num_questions
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.question_token_id = question_token_id
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_ids[:, 1] = self.question_token_id
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ start_positions = None
+ end_positions = None
+ question_positions = None
+ if self.use_labels:
+ start_positions = ids_tensor([self.batch_size, self.num_questions], self.type_sequence_label_size)
+ end_positions = ids_tensor([self.batch_size, self.num_questions], self.type_sequence_label_size)
+ question_positions = ids_tensor([self.batch_size, self.num_questions], self.num_labels)
+
+ config = SplinterConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ question_token_id=self.question_token_id,
+ )
+
+ return (config, input_ids, token_type_ids, input_mask, start_positions, end_positions, question_positions)
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ start_positions,
+ end_positions,
+ question_positions,
+ ):
+ model = SplinterModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+ result = model(input_ids, token_type_ids=token_type_ids)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_question_answering(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ start_positions,
+ end_positions,
+ question_positions,
+ ):
+ model = SplinterForQuestionAnswering(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ start_positions=start_positions[:, 0],
+ end_positions=end_positions[:, 0],
+ )
+ self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+ self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+
+ def create_and_check_for_pretraining(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ start_positions,
+ end_positions,
+ question_positions,
+ ):
+ model = SplinterForPreTraining(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ start_positions=start_positions,
+ end_positions=end_positions,
+ question_positions=question_positions,
+ )
+ self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.num_questions, self.seq_length))
+ self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.num_questions, self.seq_length))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ start_positions,
+ end_positions,
+ question_positions,
+ ) = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "token_type_ids": token_type_ids,
+ "attention_mask": input_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class SplinterModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (
+ (
+ SplinterModel,
+ SplinterForQuestionAnswering,
+ SplinterForPreTraining,
+ )
+ if is_torch_available()
+ else ()
+ )
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = copy.deepcopy(inputs_dict)
+ if return_labels:
+ if issubclass(model_class, SplinterForPreTraining):
+ inputs_dict["start_positions"] = torch.zeros(
+ self.model_tester.batch_size,
+ self.model_tester.num_questions,
+ dtype=torch.long,
+ device=torch_device,
+ )
+ inputs_dict["end_positions"] = torch.zeros(
+ self.model_tester.batch_size,
+ self.model_tester.num_questions,
+ dtype=torch.long,
+ device=torch_device,
+ )
+ inputs_dict["question_positions"] = torch.zeros(
+ self.model_tester.batch_size,
+ self.model_tester.num_questions,
+ dtype=torch.long,
+ device=torch_device,
+ )
+ elif issubclass(model_class, SplinterForQuestionAnswering):
+ inputs_dict["start_positions"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+ inputs_dict["end_positions"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+
+ return inputs_dict
+
+ def setUp(self):
+ self.model_tester = SplinterModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=SplinterConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_various_embeddings(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ for type in ["absolute", "relative_key", "relative_key_query"]:
+ config_and_inputs[0].position_embedding_type = type
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_question_answering(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
+ def test_for_pretraining(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
+
+ def test_inputs_embeds(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+ if not self.is_encoder_decoder:
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+ else:
+ encoder_input_ids = inputs["input_ids"]
+ decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+ del inputs["input_ids"]
+ inputs.pop("decoder_input_ids", None)
+
+ wte = model.get_input_embeddings()
+ if not self.is_encoder_decoder:
+ inputs["inputs_embeds"] = wte(input_ids)
+ else:
+ inputs["inputs_embeds"] = wte(encoder_input_ids)
+ inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+ with torch.no_grad():
+ if isinstance(model, SplinterForPreTraining):
+ with self.assertRaises(TypeError):
+ # question_positions must not be None.
+ model(**inputs)[0]
+ else:
+ model(**inputs)[0]
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = SplinterModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class SplinterModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_splinter_question_answering(self):
+ model = SplinterForQuestionAnswering.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] Brad was born in [QUESTION] . He returned to the United Kingdom later . [SEP]"
+ # Output should be the span "the United Kingdom"
+ input_ids = torch.tensor(
+ [[101, 7796, 1108, 1255, 1107, 104, 119, 1124, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
+ )
+ output = model(input_ids)
+
+ expected_shape = torch.Size((1, 16))
+ self.assertEqual(output.start_logits.shape, expected_shape)
+ self.assertEqual(output.end_logits.shape, expected_shape)
+
+ self.assertEqual(torch.argmax(output.start_logits), 10)
+ self.assertEqual(torch.argmax(output.end_logits), 12)
+
+ @slow
+ def test_splinter_pretraining(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
+ # Output should be the spans "Brad" and "the United Kingdom"
+ input_ids = torch.tensor(
+ [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
+ )
+ question_positions = torch.tensor([[1, 5]], dtype=torch.long)
+ output = model(input_ids, question_positions=question_positions)
+
+ expected_shape = torch.Size((1, 2, 16))
+ self.assertEqual(output.start_logits.shape, expected_shape)
+ self.assertEqual(output.end_logits.shape, expected_shape)
+
+ self.assertEqual(torch.argmax(output.start_logits[0, 0]), 7)
+ self.assertEqual(torch.argmax(output.end_logits[0, 0]), 7)
+ self.assertEqual(torch.argmax(output.start_logits[0, 1]), 10)
+ self.assertEqual(torch.argmax(output.end_logits[0, 1]), 12)
+
+ @slow
+ def test_splinter_pretraining_loss_requires_question_positions(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
+ # Output should be the spans "Brad" and "the United Kingdom"
+ input_ids = torch.tensor(
+ [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
+ )
+ start_positions = torch.tensor([[7, 10]], dtype=torch.long)
+ end_positions = torch.tensor([7, 12], dtype=torch.long)
+ with self.assertRaises(TypeError):
+ model(
+ input_ids,
+ start_positions=start_positions,
+ end_positions=end_positions,
+ )
+
+ @slow
+ def test_splinter_pretraining_loss(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
+ # Output should be the spans "Brad" and "the United Kingdom"
+ input_ids = torch.tensor(
+ [
+ [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
+ [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
+ ]
+ )
+ start_positions = torch.tensor([[7, 10], [7, 10]], dtype=torch.long)
+ end_positions = torch.tensor([[7, 12], [7, 12]], dtype=torch.long)
+ question_positions = torch.tensor([[1, 5], [1, 5]], dtype=torch.long)
+ output = model(
+ input_ids,
+ start_positions=start_positions,
+ end_positions=end_positions,
+ question_positions=question_positions,
+ )
+ self.assertAlmostEqual(output.loss.item(), 0.0024, 4)
+
+ @slow
+ def test_splinter_pretraining_loss_with_padding(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
+ # Output should be the spans "Brad" and "the United Kingdom"
+ input_ids = torch.tensor(
+ [
+ [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
+ ]
+ )
+ start_positions = torch.tensor([[7, 10]], dtype=torch.long)
+ end_positions = torch.tensor([7, 12], dtype=torch.long)
+ question_positions = torch.tensor([[1, 5]], dtype=torch.long)
+ start_positions_with_padding = torch.tensor([[7, 10, 0]], dtype=torch.long)
+ end_positions_with_padding = torch.tensor([7, 12, 0], dtype=torch.long)
+ question_positions_with_padding = torch.tensor([[1, 5, 0]], dtype=torch.long)
+ output = model(
+ input_ids,
+ start_positions=start_positions,
+ end_positions=end_positions,
+ question_positions=question_positions,
+ )
+ output_with_padding = model(
+ input_ids,
+ start_positions=start_positions_with_padding,
+ end_positions=end_positions_with_padding,
+ question_positions=question_positions_with_padding,
+ )
+
+ self.assertAlmostEqual(output.loss.item(), output_with_padding.loss.item(), 4)
+
+ # Note that the original code uses 0 to denote padded question tokens
+ # and their start and end positions. As the pad_token_id of the model's
+ # config is used for the losse's ignore_index in SplinterForPreTraining,
+ # we add this test to ensure anybody making changes to the default
+ # value of the config, will be aware of the implication.
+ self.assertEqual(model.config.pad_token_id, 0)
+
+ @slow
+ def test_splinter_pretraining_prepare_question_positions(self):
+ model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")
+
+ input_ids = torch.tensor(
+ [
+ [101, 104, 1, 2, 104, 3, 4, 102],
+ [101, 1, 104, 2, 104, 3, 104, 102],
+ [101, 1, 2, 104, 104, 3, 4, 102],
+ [101, 1, 2, 3, 4, 5, 104, 102],
+ ]
+ )
+ question_positions = torch.tensor([[1, 4, 0], [2, 4, 6], [3, 4, 0], [6, 0, 0]], dtype=torch.long)
+ output_without_positions = model(input_ids)
+ output_with_positions = model(input_ids, question_positions=question_positions)
+ self.assertTrue((output_without_positions.start_logits == output_with_positions.start_logits).all())
+ self.assertTrue((output_without_positions.end_logits == output_with_positions.end_logits).all())
diff --git a/tests/vision_text_dual_encoder/__init__.py b/tests/models/squeezebert/__init__.py
similarity index 100%
rename from tests/vision_text_dual_encoder/__init__.py
rename to tests/models/squeezebert/__init__.py
diff --git a/tests/squeezebert/test_modeling_squeezebert.py b/tests/models/squeezebert/test_modeling_squeezebert.py
similarity index 98%
rename from tests/squeezebert/test_modeling_squeezebert.py
rename to tests/models/squeezebert/test_modeling_squeezebert.py
index c728aa2b0c2ce4..cffc4570a05918 100644
--- a/tests/squeezebert/test_modeling_squeezebert.py
+++ b/tests/models/squeezebert/test_modeling_squeezebert.py
@@ -19,8 +19,8 @@
from transformers import SqueezeBertConfig, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/squeezebert/test_tokenization_squeezebert.py b/tests/models/squeezebert/test_tokenization_squeezebert.py
similarity index 100%
rename from tests/squeezebert/test_tokenization_squeezebert.py
rename to tests/models/squeezebert/test_tokenization_squeezebert.py
diff --git a/tests/visual_bert/__init__.py b/tests/models/swin/__init__.py
similarity index 100%
rename from tests/visual_bert/__init__.py
rename to tests/models/swin/__init__.py
diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py
new file mode 100644
index 00000000000000..0c1f266816c7c6
--- /dev/null
+++ b/tests/models/swin/test_modeling_swin.py
@@ -0,0 +1,502 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Swin model. """
+
+import inspect
+import os
+import pickle
+import tempfile
+import unittest
+
+from transformers import SwinConfig
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
+ from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoFeatureExtractor
+
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
+
+
+class SwinModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=32,
+ patch_size=2,
+ num_channels=3,
+ embed_dim=16,
+ depths=[1, 2, 1],
+ num_heads=[2, 2, 4],
+ window_size=2,
+ mlp_ratio=2.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ use_absolute_embeddings=False,
+ patch_norm=True,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ is_training=True,
+ scope=None,
+ use_labels=True,
+ type_sequence_label_size=10,
+ encoder_stride=8,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_absolute_embeddings = use_absolute_embeddings
+ self.patch_norm = patch_norm
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.is_training = is_training
+ self.scope = scope
+ self.use_labels = use_labels
+ self.type_sequence_label_size = type_sequence_label_size
+ self.encoder_stride = encoder_stride
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return SwinConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ embed_dim=self.embed_dim,
+ depths=self.depths,
+ num_heads=self.num_heads,
+ window_size=self.window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ drop_path_rate=self.drop_path_rate,
+ hidden_act=self.hidden_act,
+ use_absolute_embeddings=self.use_absolute_embeddings,
+ path_norm=self.patch_norm,
+ layer_norm_eps=self.layer_norm_eps,
+ initializer_range=self.initializer_range,
+ encoder_stride=self.encoder_stride,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = SwinModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
+ expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.type_sequence_label_size
+ model = SwinForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ pixel_values,
+ labels,
+ ) = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class SwinModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (
+ (
+ SwinModel,
+ SwinForImageClassification,
+ SwinForMaskedImageModeling,
+ )
+ if is_torch_available()
+ else ()
+ )
+ fx_compatible = True
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = SwinModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=SwinConfig, embed_dim=37)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_inputs_embeds(self):
+ # Swin does not use inputs_embeds
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ expected_num_attentions = len(self.model_tester.depths)
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ window_size_squared = config.window_size**2
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ else:
+ # also another +1 for reshaped_hidden_states
+ added_hidden_states = 2
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), expected_num_attentions)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+ )
+
+ def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # Swin has a different seq_length
+ patch_size = to_2tuple(config.patch_size)
+
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ reshaped_hidden_states = outputs.reshaped_hidden_states
+ self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
+
+ batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
+ reshaped_hidden_states = (
+ reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
+ )
+ self.assertListEqual(
+ list(reshaped_hidden_states.shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ def test_hidden_states_output(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ image_size = to_2tuple(self.model_tester.image_size)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ def test_hidden_states_output_with_padding(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.patch_size = 3
+
+ image_size = to_2tuple(self.model_tester.image_size)
+ patch_size = to_2tuple(config.patch_size)
+
+ padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
+ padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = SwinModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if "embeddings" not in name and param.requires_grad:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except RuntimeError as e:
+ self.fail(f"Couldn't trace module: {e}")
+
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+
+@require_vision
+@require_torch
+class SwinModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return (
+ AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+ if is_vision_available()
+ else None
+ )
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)
+ feature_extractor = self.default_feature_extractor
+
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+ expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
diff --git a/tests/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_tf_swin.py
similarity index 66%
rename from tests/swin/test_modeling_swin.py
rename to tests/models/swin/test_modeling_tf_swin.py
index 2147f578e73ea0..88323d7fd7a594 100644
--- a/tests/swin/test_modeling_swin.py
+++ b/tests/models/swin/test_modeling_tf_swin.py
@@ -12,26 +12,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Testing suite for the PyTorch Swin model. """
+""" Testing suite for the TF 2.0 Swin model. """
+
-import copy
import inspect
import unittest
+import numpy as np
+
from transformers import SwinConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
-from transformers.utils import cached_property, is_torch_available, is_vision_available
+from transformers.testing_utils import require_tf, require_vision, slow
+from transformers.utils import cached_property, is_tf_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
-if is_torch_available():
- import torch
- from torch import nn
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers.models.swin.modeling_tf_swin import (
+ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
+ TFSwinModel,
+ to_2tuple,
+ )
- from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
- from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
if is_vision_available():
from PIL import Image
@@ -39,15 +46,7 @@
from transformers import AutoFeatureExtractor
-def _config_zero_init(config):
- configs_no_init = copy.deepcopy(config)
- for key in configs_no_init.__dict__.keys():
- if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
- setattr(configs_no_init, key, 1e-10)
- return configs_no_init
-
-
-class SwinModelTester:
+class TFSwinModelTester:
def __init__(
self,
parent,
@@ -74,7 +73,7 @@ def __init__(
use_labels=True,
type_sequence_label_size=10,
encoder_stride=8,
- ):
+ ) -> None:
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
@@ -134,9 +133,7 @@ def get_config(self):
)
def create_and_check_model(self, config, pixel_values, labels):
- model = SwinModel(config=config)
- model.to(torch_device)
- model.eval()
+ model = TFSwinModel(config=config)
result = model(pixel_values)
expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
@@ -146,42 +143,37 @@ def create_and_check_model(self, config, pixel_values, labels):
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
- model = SwinForImageClassification(config)
- model.to(torch_device)
- model.eval()
+ model = TFSwinForImageClassification(config)
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
- (
- config,
- pixel_values,
- labels,
- ) = config_and_inputs
+ config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
-@require_torch
-class SwinModelTest(ModelTesterMixin, unittest.TestCase):
+@require_tf
+class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
- SwinModel,
- SwinForImageClassification,
- SwinForMaskedImageModeling,
+ TFSwinModel,
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
)
- if is_torch_available()
+ if is_tf_available()
else ()
)
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
+ test_onnx = False
def setUp(self):
- self.model_tester = SwinModelTester(self)
+ self.model_tester = TFSwinModelTester(self)
self.config_tester = ConfigTester(self, config_class=SwinConfig, embed_dim=37)
def test_config(self):
@@ -200,8 +192,8 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ @unittest.skip(reason="Swin does not use inputs_embeds")
def test_inputs_embeds(self):
- # Swin does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -209,16 +201,16 @@ def test_model_common_attributes(self):
for model_class in self.all_model_classes:
model = model_class(config)
- self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)
x = model.get_output_embeddings()
- self.assertTrue(x is None or isinstance(x, nn.Linear))
+ self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
- signature = inspect.signature(model.forward)
+ signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
@@ -234,10 +226,7 @@ def test_attention_outputs(self):
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
expected_num_attentions = len(self.model_tester.depths)
self.assertEqual(len(attentions), expected_num_attentions)
@@ -247,10 +236,7 @@ def test_attention_outputs(self):
config.output_attentions = True
window_size_squared = config.window_size**2
model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)
@@ -264,10 +250,7 @@ def test_attention_outputs(self):
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
@@ -285,56 +268,73 @@ def test_attention_outputs(self):
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
)
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
+ def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
+ model = model_class(config)
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ hidden_states = outputs.hidden_states
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
- hidden_states = outputs.hidden_states
+ # Swin has a different seq_length
+ patch_size = to_2tuple(config.patch_size)
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- # Swin has a different seq_length
- image_size = to_2tuple(self.model_tester.image_size)
- patch_size = to_2tuple(self.model_tester.patch_size)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ reshaped_hidden_states = outputs.reshaped_hidden_states
+ self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [num_patches, self.model_tester.embed_dim],
- )
+ batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
- reshaped_hidden_states = outputs.reshaped_hidden_states
- self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
+ reshaped_hidden_states = tf.reshape(reshaped_hidden_states[0], (batch_size, num_channels, height * width))
+ reshaped_hidden_states = tf.transpose(reshaped_hidden_states, (0, 2, 1))
- batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
- reshaped_hidden_states = (
- reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
- )
- self.assertListEqual(
- list(reshaped_hidden_states.shape[-2:]),
- [num_patches, self.model_tester.embed_dim],
- )
+ self.assertListEqual(
+ list(reshaped_hidden_states.shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ image_size = to_2tuple(self.model_tester.image_size)
+
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
- check_hidden_states_output(inputs_dict, config, model_class)
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ def test_inputs_requiring_padding(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.patch_size = 3
+
+ image_size = to_2tuple(self.model_tester.image_size)
+ patch_size = to_2tuple(config.patch_size)
+
+ padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
+ padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -342,28 +342,14 @@ def test_for_image_classification(self):
@slow
def test_model_from_pretrained(self):
- for model_name in SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
- model = SwinModel.from_pretrained(model_name)
+ for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TFSwinModel.from_pretrained(model_name)
self.assertIsNotNone(model)
- def test_initialization(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- configs_no_init = _config_zero_init(config)
- for model_class in self.all_model_classes:
- model = model_class(config=configs_no_init)
- for name, param in model.named_parameters():
- if "embeddings" not in name and param.requires_grad:
- self.assertIn(
- ((param.data.mean() * 1e9).round() / 1e9).item(),
- [0.0, 1.0],
- msg=f"Parameter {name} of model {model_class} seems not properly initialized",
- )
-
@require_vision
-@require_torch
-class SwinModelIntegrationTest(unittest.TestCase):
+@require_tf
+class TFSwinModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
@@ -374,18 +360,17 @@ def default_feature_extractor(self):
@slow
def test_inference_image_classification_head(self):
- model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)
+ model = TFSwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
feature_extractor = self.default_feature_extractor
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
- inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+ inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
- with torch.no_grad():
- outputs = model(**inputs)
+ outputs = model(inputs)
# verify the logits
- expected_shape = torch.Size((1, 1000))
+ expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
- self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+ expected_slice = tf.constant([-0.0948, -0.6454, -0.0921])
+ self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
diff --git a/tests/vit/__init__.py b/tests/models/t5/__init__.py
similarity index 100%
rename from tests/vit/__init__.py
rename to tests/models/t5/__init__.py
diff --git a/tests/t5/test_modeling_flax_t5.py b/tests/models/t5/test_modeling_flax_t5.py
similarity index 52%
rename from tests/t5/test_modeling_flax_t5.py
rename to tests/models/t5/test_modeling_flax_t5.py
index f4d8ebbab1ae39..f3b2c166ed12b6 100644
--- a/tests/t5/test_modeling_flax_t5.py
+++ b/tests/models/t5/test_modeling_flax_t5.py
@@ -27,9 +27,9 @@
slow,
)
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
@@ -573,16 +573,208 @@ def test_summarization(self):
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
tok = T5Tokenizer.from_pretrained("t5-base")
- FRANCE_ARTICLE = 'Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
- SHORTER_ARTICLE = '(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
- IRAN_ARTICLE = "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
- ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ FRANCE_ARTICLE = ( # @noqa
+ "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
+ SHORTER_ARTICLE = (
+ "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+ IRAN_ARTICLE = (
+ "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
+ ARTICLE_SUBWAY = (
+ "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
expected_summaries = [
- 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says . all 150 on board were killed when germanwings flight 9525 crashed .',
- "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
- "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut centrifuges . miller: if it had been, there would have been no Iranian team at the table .",
- 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
+ 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
+ " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
+ " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
+ "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
+ " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
+ " court, Palestinians may be subject to counter-charges as well .",
+ "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
+ " the debate that has already begun since the announcement of the new framework will likely result in more"
+ " heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut"
+ " centrifuges . miller: if it had been, there would have been no Iranian team at the table .",
+ "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
+ ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
+ " times, with nine of her marriages occurring between 1999 and 2002 .",
]
dct = tok(
diff --git a/tests/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
similarity index 64%
rename from tests/t5/test_modeling_t5.py
rename to tests/models/t5/test_modeling_t5.py
index 8380484b06065f..3ed5521a62d0b6 100644
--- a/tests/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -22,9 +22,9 @@
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.utils import cached_property
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -509,12 +509,14 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
- fx_compatible = True
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
test_resize_embeddings = True
test_model_parallel = True
is_encoder_decoder = True
+ # The small T5 model needs higher percentages for CPU/MP tests
+ model_split_percents = [0.8, 0.9]
def setUp(self):
self.model_tester = T5ModelTester(self)
@@ -539,6 +541,12 @@ def test_model_v1_1(self):
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
+ def test_config_and_model_silu_gated(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ config = config_and_inputs[0]
+ config.feed_forward_proj = "gated-silu"
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
@@ -654,6 +662,10 @@ def test_generate_with_head_masking(self):
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
+ @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
+ def test_disk_offload(self):
+ pass
+
class T5EncoderOnlyModelTester:
def __init__(
@@ -909,16 +921,208 @@ def test_summarization(self):
model = self.model
tok = self.tokenizer
- FRANCE_ARTICLE = 'Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
- SHORTER_ARTICLE = '(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
- IRAN_ARTICLE = "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
- ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ FRANCE_ARTICLE = ( # @noqa
+ "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
+ SHORTER_ARTICLE = (
+ "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+ IRAN_ARTICLE = (
+ "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
+ ARTICLE_SUBWAY = (
+ "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
expected_summaries = [
- 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
- "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
- "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
- 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
+ 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
+ " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
+ " magazine says .",
+ "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
+ " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
+ " court, Palestinians may be subject to counter-charges as well .",
+ "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
+ " the debate that has already begun since the announcement of the new framework will likely result in more"
+ " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
+ " implement a rigorous inspection regime .",
+ "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
+ ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
+ " times, with nine of her marriages occurring between 1999 and 2002 .",
]
use_task_specific_params(model, "summarization")
@@ -971,7 +1175,10 @@ def test_translation_en_to_fr(self):
tok = self.tokenizer
use_task_specific_params(model, "translation_en_to_fr")
- en_text = ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of countless generations of stars: the oldest stars are seen as blue dots. '
+ en_text = (
+ ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of'
+ " countless generations of stars: the oldest stars are seen as blue dots. "
+ )
input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt")
input_ids = input_ids.to(torch_device)
diff --git a/tests/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py
similarity index 59%
rename from tests/t5/test_modeling_tf_t5.py
rename to tests/models/t5/test_modeling_tf_t5.py
index d3008f017de764..5ad746e34fc877 100644
--- a/tests/t5/test_modeling_tf_t5.py
+++ b/tests/models/t5/test_modeling_tf_t5.py
@@ -19,8 +19,8 @@
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -295,6 +295,13 @@ def test_t5_decoder_model_past_with_attn_mask(self):
def test_t5_decoder_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
+
+ # `create_and_check_t5_decoder_model_past_large_inputs` has special inputs:
+ # (config, input_ids, decoder_input_ids, attention_mask)
+ # and we have to prepare it correctly here.
+ config, input_ids, input_mask, token_labels = config_and_inputs
+ config_and_inputs = (config, input_ids, None, input_mask)
+
self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
def test_t5_model_xla_generate_fast(self):
@@ -700,19 +707,211 @@ def test_summarization(self):
model = self.model
tok = T5Tokenizer.from_pretrained("t5-base")
- FRANCE_ARTICLE = 'Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
+ FRANCE_ARTICLE = ( # @noqa
+ "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
+ " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
+ ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
+ ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
+ " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
+ " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
+ " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
+ " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
+ " their websites. The publications said that they watched the video, which was found by a source close to"
+ " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
+ ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
+ " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the"
+ ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
+ " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
+ " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
+ " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
+ ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
+ ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
+ " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
+ " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
+ " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
+ ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
+ ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
+ ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
+ ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
+ " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
+ ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
+ " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
+ " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
+ ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
+ ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
+ " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
+ " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
+ " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
+ " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
+ ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
+ " sharing the information and documents -- including training and medical records -- with public"
+ " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
+ " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
+ " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
+ " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
+ " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
+ " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
+ " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
+ " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
+ " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
+ " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
+ " the flight school during his training were among several developments as investigators continued to"
+ " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
+ " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
+ ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
+ " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
+ " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
+ " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
+ " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
+ " lose his pilot's license, a European government official briefed on the investigation told CNN on"
+ ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
+ " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
+ " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
+ " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
+ " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
+ " he had psychological issues, the European government official said. But no matter what details emerge"
+ " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
+ ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
+ " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
+ ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
+ " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
+ ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
+ " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
+ " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
+ " Amiel and Anna-Maja Rappard contributed to this report."
+ )
- SHORTER_ARTICLE = '(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
+ SHORTER_ARTICLE = (
+ "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
- IRAN_ARTICLE = "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
+ IRAN_ARTICLE = (
+ "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
+ " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
+ " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
+ " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
+ " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
+ " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
+ " the announcement of the new framework will likely result in more heat than light. It will not be helped"
+ " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
+ " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
+ " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
+ " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
+ " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
+ " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
+ " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
+ " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
+ " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
+ " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
+ " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
+ " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
+ " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
+ " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
+ " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
+ " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
+ " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
+ " warning that a deal might be killed by Congress or a future president). This of course is not the case."
+ " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
+ " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
+ " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
+ " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
+ " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
+ " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
+ " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
+ " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
+ " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
+ " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
+ " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
+ " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
+ ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
+ " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
+ " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
+ " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement"
+ " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
+ " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
+ " some insist that any agreement must address Iranian missile programs, human rights violations or support"
+ " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are"
+ " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in"
+ " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
+ " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
+ " fact-based, not based on questionable assertions or dubious assumptions."
+ )
- ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
+ ARTICLE_SUBWAY = (
+ "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
+ " year later, she got married again in Westchester County, but to a different man and without divorcing"
+ " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
+ ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
+ " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
+ ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
+ ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
+ " license application, according to court documents. Prosecutors said the marriages were part of an"
+ " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
+ " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
+ " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
+ " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
+ " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
+ " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
+ " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
+ " said the immigration scam involved some of her husbands, who filed for permanent residence status"
+ " shortly after the marriages. Any divorces happened only after such filings were approved. It was"
+ " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
+ " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
+ ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
+ " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
+ " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
+ " up to four years in prison. Her next court appearance is scheduled for May 18."
+ )
expected_summaries = [
- 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
- "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
- "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
- 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
+ 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
+ " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
+ " magazine says .",
+ "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
+ " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
+ " court, Palestinians may be subject to counter-charges as well .",
+ "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
+ " the debate that has already begun since the announcement of the new framework will likely result in more"
+ " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
+ " implement a rigorous inspection regime .",
+ "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
+ ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
+ " times, with nine of her marriages occurring between 1999 and 2002 .",
]
task_specific_config = getattr(model.config, "task_specific_params", {})
@@ -787,7 +986,10 @@ def test_translation_en_to_fr(self):
translation_config = task_specific_config.get("translation_en_to_fr", {})
model.config.update(translation_config)
- en_text = ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of countless generations of stars: the oldest stars are seen as blue dots. '
+ en_text = (
+ ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of'
+ " countless generations of stars: the oldest stars are seen as blue dots. "
+ )
new_truncated_translation = (
"Cette section d'images provenant de l'enregistrement infrarouge effectuƩ par le tƩlescope Spitzer montre "
diff --git a/tests/t5/test_tokenization_t5.py b/tests/models/t5/test_tokenization_t5.py
similarity index 95%
rename from tests/t5/test_tokenization_t5.py
rename to tests/models/t5/test_tokenization_t5.py
index 2deaa21f3ac338..1c0fde222cdb5f 100644
--- a/tests/t5/test_tokenization_t5.py
+++ b/tests/models/t5/test_tokenization_t5.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property, is_tf_available, is_torch_available
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@@ -223,6 +223,9 @@ def test_outputs_not_longer_than_maxlen(self):
["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
)
self.assertIsInstance(batch, BatchEncoding)
+ # Since T5 does NOT have a max input length,
+ # this test should be changed to the following in Transformers v5:
+ # self.assertEqual(batch.input_ids.shape, (2, 8001))
self.assertEqual(batch.input_ids.shape, (2, 512))
def test_eos_in_input(self):
@@ -361,6 +364,13 @@ def test_special_tokens_initialization_with_non_empty_additional_special_tokens(
),
)
+ # overwritten from `test_tokenization_common` since T5 has no max length
+ def test_pretrained_model_lists(self):
+ # We should have at least one default checkpoint for each tokenizer
+ # We should specify the max input length as well (used in some part to list the pretrained checkpoints)
+ self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
+ self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
+
@slow
def test_tokenizer_integration(self):
# fmt: off
diff --git a/tests/vit_mae/__init__.py b/tests/models/tapas/__init__.py
similarity index 100%
rename from tests/vit_mae/__init__.py
rename to tests/models/tapas/__init__.py
diff --git a/tests/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py
similarity index 99%
rename from tests/tapas/test_modeling_tapas.py
rename to tests/models/tapas/test_modeling_tapas.py
index 385af04dedadb4..b7b4af6e5a2ad5 100644
--- a/tests/tapas/test_modeling_tapas.py
+++ b/tests/models/tapas/test_modeling_tapas.py
@@ -32,11 +32,17 @@
is_torch_available,
)
from transformers.models.auto import get_values
-from transformers.testing_utils import require_scatter, require_torch, slow, torch_device
+from transformers.testing_utils import (
+ require_scatter,
+ require_tensorflow_probability,
+ require_torch,
+ slow,
+ torch_device,
+)
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -499,6 +505,10 @@ def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+ @require_tensorflow_probability
+ def test_pt_tf_model_equivalence(self):
+ super().test_pt_tf_model_equivalence()
+
def prepare_tapas_single_inputs_for_inference():
# Here we prepare a single table-question pair to test TAPAS inference on:
diff --git a/tests/tapas/test_modeling_tf_tapas.py b/tests/models/tapas/test_modeling_tf_tapas.py
similarity index 99%
rename from tests/tapas/test_modeling_tf_tapas.py
rename to tests/models/tapas/test_modeling_tf_tapas.py
index 9e3cb63f70b546..bf5e8be370c775 100644
--- a/tests/tapas/test_modeling_tf_tapas.py
+++ b/tests/models/tapas/test_modeling_tf_tapas.py
@@ -37,8 +37,8 @@
from transformers.testing_utils import require_tensorflow_probability, require_tf, slow
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
@@ -498,6 +498,10 @@ def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+ @unittest.skip(reason="The default test gets NaN losses with the test-generated inputs")
+ def test_dataset_conversion(self):
+ pass
+
def prepare_tapas_single_inputs_for_inference():
# Here we prepare a single table-question pair to test TAPAS inference on:
diff --git a/tests/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py
similarity index 99%
rename from tests/tapas/test_tokenization_tapas.py
rename to tests/models/tapas/test_tokenization_tapas.py
index a5c6da2a41b950..f712f324f95489 100644
--- a/tests/tapas/test_tokenization_tapas.py
+++ b/tests/models/tapas/test_tokenization_tapas.py
@@ -36,12 +36,13 @@
is_pt_tf_cross_test,
require_pandas,
require_scatter,
+ require_tensorflow_probability,
require_tokenizers,
require_torch,
slow,
)
-from ..test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
@require_tokenizers
@@ -141,6 +142,10 @@ def get_input_output_texts(self, tokenizer):
output_text = "unwanted, running"
return input_text, output_text
+ @require_tensorflow_probability
+ def test_tf_encode_plus_sent_to_model(self):
+ super().test_tf_encode_plus_sent_to_model()
+
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
@@ -251,7 +256,7 @@ def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
vocab = {}
- for (i, token) in enumerate(vocab_tokens):
+ for i, token in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
diff --git a/tests/wav2vec2/__init__.py b/tests/models/tapex/__init__.py
similarity index 100%
rename from tests/wav2vec2/__init__.py
rename to tests/models/tapex/__init__.py
diff --git a/tests/tapex/test_tokenization_tapex.py b/tests/models/tapex/test_tokenization_tapex.py
similarity index 99%
rename from tests/tapex/test_tokenization_tapex.py
rename to tests/models/tapex/test_tokenization_tapex.py
index dd9f3d4bcf25d0..c959b780215b9b 100644
--- a/tests/tapex/test_tokenization_tapex.py
+++ b/tests/models/tapex/test_tokenization_tapex.py
@@ -27,7 +27,7 @@
from transformers.models.tapex.tokenization_tapex import VOCAB_FILES_NAMES
from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_pandas
diff --git a/tests/wav2vec2_phoneme/__init__.py b/tests/models/trajectory_transformer/__init__.py
similarity index 100%
rename from tests/wav2vec2_phoneme/__init__.py
rename to tests/models/trajectory_transformer/__init__.py
diff --git a/tests/models/trajectory_transformer/test_modeling_trajectory_transformer.py b/tests/models/trajectory_transformer/test_modeling_trajectory_transformer.py
new file mode 100644
index 00000000000000..7cf5c741a1f6fa
--- /dev/null
+++ b/tests/models/trajectory_transformer/test_modeling_trajectory_transformer.py
@@ -0,0 +1,275 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch TrajectoryTransformer model. """
+
+
+import inspect
+import unittest
+
+import numpy as np
+
+from transformers import TrajectoryTransformerConfig, is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import TrajectoryTransformerModel
+ from transformers.models.trajectory_transformer.modeling_trajectory_transformer import (
+ TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+ )
+
+
+class TrajectoryTransformerModelTester:
+ def __init__(self, parent, batch_size=13, n_embd=128, action_dim=6, observation_dim=17, is_training=True):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.n_embd = n_embd
+ self.action_dim = action_dim
+ self.observation_dim = observation_dim
+ self.is_training = is_training
+ self.seq_length = self.action_dim + self.observation_dim + 1
+
+ def prepare_config_and_inputs(self):
+ trajectories = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(self.batch_size)]).to(
+ torch_device
+ )
+ attention_mask = random_attention_mask((self.batch_size, self.seq_length)).to(torch_device)
+ targets = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(self.batch_size)]).to(
+ torch_device
+ )
+
+ config = self.get_config()
+ return config, trajectories, attention_mask, targets
+
+ def get_config(self):
+ return TrajectoryTransformerConfig(
+ batch_size=self.batch_size,
+ n_embd=self.n_embd,
+ action_dim=self.action_dim,
+ observation_dim=self.observation_dim,
+ )
+
+ def create_and_check_model(self, config, input_dict):
+ model = TrajectoryTransformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(trajectories=input_dict["trajectories"], attention_mask=input_dict["attention_mask"])
+ result = model(
+ trajectories=input_dict["trajectories"],
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ self.parent.assertEqual(result.hidden_states[-1].shape, (self.batch_size, self.seq_length, self.n_embd))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (config, trajectories, attention_mask, targets) = config_and_inputs
+ inputs_dict = {"trajectories": trajectories, "attention_mask": attention_mask, "targets": targets}
+ return config, inputs_dict
+
+
+@require_torch
+class TrajectoryTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (TrajectoryTransformerModel,) if is_torch_available() else ()
+
+ # Ignoring of a failing test from GenerationTesterMixin, as the model does not use inputs_ids
+ test_generate_without_input_ids = False
+
+ # Ignoring of a failing tests from ModelTesterMixin, as the model does not implement these features
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ test_attention_outputs = False
+ test_hidden_states_output = False
+ test_inputs_embeds = False
+ test_model_common_attributes = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = TrajectoryTransformerModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=TrajectoryTransformerConfig, n_embd=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_conditional_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["trajectories"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ # # Input is 'trajectories' not 'input_ids'
+ def test_model_main_input_name(self):
+ model_signature = inspect.signature(getattr(TrajectoryTransformerModel, "forward"))
+ # The main input is the name of the argument after `self`
+ observed_main_input_name = list(model_signature.parameters.keys())[1]
+ self.assertEqual(TrajectoryTransformerModel.main_input_name, observed_main_input_name)
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = self.has_attentions
+
+ model = TrajectoryTransformerModel(config)
+ model.to(torch_device)
+
+ outputs = model(
+ trajectories=input_dict["trajectories"],
+ attention_mask=input_dict["attention_mask"],
+ targets=input_dict["targets"],
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ output = outputs[0]
+ hidden_states = outputs.hidden_states[0]
+ hidden_states.retain_grad()
+
+ if self.has_attentions:
+ attentions = outputs.attentions[0]
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+
+ if self.has_attentions:
+ self.assertIsNotNone(attentions.grad)
+
+ def test_training(self):
+ if not self.model_tester.is_training:
+ return
+
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ model = TrajectoryTransformerModel(config)
+ model.to(torch_device)
+ model.train()
+ loss = model(
+ trajectories=input_dict["trajectories"],
+ attention_mask=input_dict["attention_mask"],
+ targets=input_dict["targets"],
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=True,
+ return_dict=True,
+ ).loss
+ loss.backward()
+
+ def test_training_gradient_checkpointing(self):
+ if not self.model_tester.is_training:
+ return
+
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ model = TrajectoryTransformerModel(config)
+ model.gradient_checkpointing_enable()
+ model.to(torch_device)
+ model.train()
+ loss = model(
+ trajectories=input_dict["trajectories"],
+ attention_mask=input_dict["attention_mask"],
+ targets=input_dict["targets"],
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=False,
+ return_dict=True,
+ ).loss
+ loss.backward()
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = TrajectoryTransformerModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class TrajectoryTransformerModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_prediction(self):
+ batch_size = 1
+
+ config = TrajectoryTransformerConfig.from_pretrained("CarlCochet/trajectory-transformer-halfcheetah-medium-v2")
+ model = TrajectoryTransformerModel.from_pretrained(
+ "CarlCochet/trajectory-transformer-halfcheetah-medium-v2", config=config
+ )
+ model.to(torch_device)
+ model.eval()
+
+ seq_length = model.config.action_dim + model.config.observation_dim + 1
+
+ trajectories = torch.LongTensor(
+ [[3, 19, 20, 22, 9, 7, 23, 10, 18, 14, 13, 4, 17, 11, 5, 6, 15, 21, 2, 8, 1, 0, 12, 16]]
+ ).to(torch_device)
+ outputs = model(
+ trajectories=trajectories,
+ output_hidden_states=True,
+ output_attentions=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ output = outputs.logits
+
+ expected_shape = torch.Size((batch_size, seq_length, model.config.vocab_size + 1))
+ expected_slice = torch.tensor(
+ [[[-0.7193, -0.2532, -0.0898], [1.9429, 2.0434, 2.3975], [-3.3651, -2.8744, -2.4532]]]
+ ).to(torch_device)
+ output_slice = output[:, :3, :3]
+
+ self.assertEqual(output.shape, expected_shape)
+ self.assertTrue(torch.allclose(output_slice, expected_slice, atol=1e-4))
diff --git a/tests/wav2vec2_with_lm/__init__.py b/tests/models/transfo_xl/__init__.py
similarity index 100%
rename from tests/wav2vec2_with_lm/__init__.py
rename to tests/models/transfo_xl/__init__.py
diff --git a/tests/transfo_xl/test_modeling_tf_transfo_xl.py b/tests/models/transfo_xl/test_modeling_tf_transfo_xl.py
similarity index 97%
rename from tests/transfo_xl/test_modeling_tf_transfo_xl.py
rename to tests/models/transfo_xl/test_modeling_tf_transfo_xl.py
index 87aca5097a4fa9..84e25d8716f5f8 100644
--- a/tests/transfo_xl/test_modeling_tf_transfo_xl.py
+++ b/tests/models/transfo_xl/test_modeling_tf_transfo_xl.py
@@ -20,8 +20,8 @@
from transformers import TransfoXLConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
@@ -216,6 +216,10 @@ def test_model_from_pretrained(self):
model = TFTransfoXLModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ @unittest.skip(reason="This model doesn't play well with fit() due to not returning a single loss.")
+ def test_dataset_conversion(self):
+ pass
+
@require_tf
class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
diff --git a/tests/transfo_xl/test_modeling_transfo_xl.py b/tests/models/transfo_xl/test_modeling_transfo_xl.py
similarity index 99%
rename from tests/transfo_xl/test_modeling_transfo_xl.py
rename to tests/models/transfo_xl/test_modeling_transfo_xl.py
index d4dbba448a9569..309811efb46512 100644
--- a/tests/transfo_xl/test_modeling_transfo_xl.py
+++ b/tests/models/transfo_xl/test_modeling_transfo_xl.py
@@ -20,9 +20,9 @@
from transformers import TransfoXLConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
diff --git a/tests/transfo_xl/test_tokenization_transfo_xl.py b/tests/models/transfo_xl/test_tokenization_transfo_xl.py
similarity index 98%
rename from tests/transfo_xl/test_tokenization_transfo_xl.py
rename to tests/models/transfo_xl/test_tokenization_transfo_xl.py
index 261fcf00445a44..3f7065c51b4739 100644
--- a/tests/transfo_xl/test_tokenization_transfo_xl.py
+++ b/tests/models/transfo_xl/test_tokenization_transfo_xl.py
@@ -19,7 +19,7 @@
from transformers.models.transfo_xl.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/wavlm/__init__.py b/tests/models/trocr/__init__.py
similarity index 100%
rename from tests/wavlm/__init__.py
rename to tests/models/trocr/__init__.py
diff --git a/tests/trocr/test_modeling_trocr.py b/tests/models/trocr/test_modeling_trocr.py
similarity index 96%
rename from tests/trocr/test_modeling_trocr.py
rename to tests/models/trocr/test_modeling_trocr.py
index b15b059f92990c..0c5e6f7ae8f9f5 100644
--- a/tests/trocr/test_modeling_trocr.py
+++ b/tests/models/trocr/test_modeling_trocr.py
@@ -19,9 +19,9 @@
from transformers import TrOCRConfig
from transformers.testing_utils import is_torch_available, require_torch, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -161,6 +161,7 @@ def prepare_config_and_inputs_for_common(self):
class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
test_pruning = False
def setUp(self):
diff --git a/tests/xglm/__init__.py b/tests/models/unispeech/__init__.py
similarity index 100%
rename from tests/xglm/__init__.py
rename to tests/models/unispeech/__init__.py
diff --git a/tests/unispeech/test_modeling_unispeech.py b/tests/models/unispeech/test_modeling_unispeech.py
similarity index 99%
rename from tests/unispeech/test_modeling_unispeech.py
rename to tests/models/unispeech/test_modeling_unispeech.py
index 9a25237bf357c2..228b0dd175f86a 100644
--- a/tests/unispeech/test_modeling_unispeech.py
+++ b/tests/models/unispeech/test_modeling_unispeech.py
@@ -24,8 +24,8 @@
from transformers import UniSpeechConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -107,7 +107,7 @@ def __init__(
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
diff --git a/tests/xlm/__init__.py b/tests/models/unispeech_sat/__init__.py
similarity index 100%
rename from tests/xlm/__init__.py
rename to tests/models/unispeech_sat/__init__.py
diff --git a/tests/unispeech_sat/test_modeling_unispeech_sat.py b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py
similarity index 99%
rename from tests/unispeech_sat/test_modeling_unispeech_sat.py
rename to tests/models/unispeech_sat/test_modeling_unispeech_sat.py
index da4359659a4ebf..6ac06e4db9be44 100644
--- a/tests/unispeech_sat/test_modeling_unispeech_sat.py
+++ b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py
@@ -24,8 +24,8 @@
from transformers import UniSpeechSatConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -121,7 +121,7 @@ def __init__(
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
@@ -306,7 +306,7 @@ def check_xvector_training(self, config, *args):
model.freeze_base_model()
# use a longer sequence length to account for TDNN temporal downsampling
- input_values = floats_tensor([self.batch_size, self.seq_length * 2], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length * 2], scale=1.0)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
diff --git a/tests/xlm_prophetnet/__init__.py b/tests/models/van/__init__.py
similarity index 100%
rename from tests/xlm_prophetnet/__init__.py
rename to tests/models/van/__init__.py
diff --git a/tests/van/test_modeling_van.py b/tests/models/van/test_modeling_van.py
similarity index 97%
rename from tests/van/test_modeling_van.py
rename to tests/models/van/test_modeling_van.py
index dff60fea38b7c5..6b6a672b9b4f24 100644
--- a/tests/van/test_modeling_van.py
+++ b/tests/models/van/test_modeling_van.py
@@ -23,8 +23,8 @@
from transformers.testing_utils import require_scipy, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_scipy_available, is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_scipy_available():
@@ -144,6 +144,10 @@ def test_config(self):
def create_and_test_config_common_properties(self):
return
+ @unittest.skip(reason="Van does not output attentions")
+ def test_attention_outputs(self):
+ pass
+
@unittest.skip(reason="Van does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/xlm_roberta/__init__.py b/tests/models/vilt/__init__.py
similarity index 100%
rename from tests/xlm_roberta/__init__.py
rename to tests/models/vilt/__init__.py
diff --git a/tests/vilt/test_feature_extraction_vilt.py b/tests/models/vilt/test_feature_extraction_vilt.py
similarity index 98%
rename from tests/vilt/test_feature_extraction_vilt.py
rename to tests/models/vilt/test_feature_extraction_vilt.py
index 7c82e63eaf698d..62a9783c815a18 100644
--- a/tests/vilt/test_feature_extraction_vilt.py
+++ b/tests/models/vilt/test_feature_extraction_vilt.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/vilt/test_modeling_vilt.py b/tests/models/vilt/test_modeling_vilt.py
similarity index 98%
rename from tests/vilt/test_modeling_vilt.py
rename to tests/models/vilt/test_modeling_vilt.py
index 2ddf9c3455d0e6..1a2f95d0e6cd94 100644
--- a/tests/vilt/test_modeling_vilt.py
+++ b/tests/models/vilt/test_modeling_vilt.py
@@ -24,8 +24,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -589,7 +589,10 @@ def test_inference_natural_language_visual_reasoning(self):
image1 = Image.open(dataset[0]["file"]).convert("RGB")
image2 = Image.open(dataset[1]["file"]).convert("RGB")
- text = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
+ text = (
+ "The left image contains twice the number of dogs as the right image, and at least two dogs in total are"
+ " standing."
+ )
encoding_1 = processor(image1, text, return_tensors="pt")
encoding_2 = processor(image2, text, return_tensors="pt")
diff --git a/tests/xlm_roberta_xl/__init__.py b/tests/models/vision_encoder_decoder/__init__.py
similarity index 100%
rename from tests/xlm_roberta_xl/__init__.py
rename to tests/models/vision_encoder_decoder/__init__.py
diff --git a/tests/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py
similarity index 99%
rename from tests/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py
rename to tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py
index 163b8ddaa231f8..f874ad1c6337fd 100644
--- a/tests/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py
@@ -22,8 +22,8 @@
from transformers import is_flax_available, is_torch_available, is_vision_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_vision, slow, torch_device
+from ...test_modeling_flax_common import floats_tensor, ids_tensor
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
-from ..test_modeling_flax_common import floats_tensor, ids_tensor
from ..vit.test_modeling_flax_vit import FlaxViTModelTester
diff --git a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
similarity index 99%
rename from tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
rename to tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
index 158aa4e5f07607..9edbd3f802fb88 100644
--- a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
@@ -32,8 +32,8 @@
torch_device,
)
+from ...test_modeling_tf_common import floats_tensor, ids_tensor
from ..gpt2.test_modeling_tf_gpt2 import TFGPT2ModelTester
-from ..test_modeling_tf_common import floats_tensor, ids_tensor
from ..vit.test_modeling_tf_vit import TFViTModelTester
diff --git a/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
similarity index 99%
rename from tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
rename to tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index b867778ec9656f..f8ac8f1cdf1c36 100644
--- a/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -23,11 +23,11 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
+from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..bart.test_modeling_bart import BartModelTester
from ..bert.test_modeling_bert import BertModelTester
from ..deit.test_modeling_deit import DeiTModelTester
from ..swin.test_modeling_swin import SwinModelTester
-from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
from ..vit.test_modeling_vit import ViTModelTester
diff --git a/tests/xlnet/__init__.py b/tests/models/vision_text_dual_encoder/__init__.py
similarity index 100%
rename from tests/xlnet/__init__.py
rename to tests/models/vision_text_dual_encoder/__init__.py
diff --git a/tests/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py
similarity index 99%
rename from tests/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py
rename to tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py
index 27893a3d3ba081..cb476c128aa685 100644
--- a/tests/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py
+++ b/tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py
@@ -31,9 +31,9 @@
)
from transformers.utils import is_flax_available, is_torch_available, is_vision_available
+from ...test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
from ..clip.test_modeling_flax_clip import FlaxCLIPVisionModelTester
-from ..test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..vit.test_modeling_flax_vit import FlaxViTModelTester
diff --git a/tests/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py
similarity index 99%
rename from tests/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py
rename to tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py
index a5fd8eb9113c85..18182047d66475 100644
--- a/tests/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py
+++ b/tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py
@@ -24,11 +24,11 @@
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device
from transformers.utils import is_flax_available, is_torch_available, is_vision_available
+from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..bert.test_modeling_bert import BertModelTester
from ..clip.test_modeling_clip import CLIPVisionModelTester
from ..deit.test_modeling_deit import DeiTModelTester
from ..roberta.test_modeling_roberta import RobertaModelTester
-from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..vit.test_modeling_vit import ViTModelTester
diff --git a/tests/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py
similarity index 100%
rename from tests/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py
rename to tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py
diff --git a/tests/yoso/__init__.py b/tests/models/visual_bert/__init__.py
similarity index 100%
rename from tests/yoso/__init__.py
rename to tests/models/visual_bert/__init__.py
diff --git a/tests/visual_bert/test_modeling_visual_bert.py b/tests/models/visual_bert/test_modeling_visual_bert.py
similarity index 99%
rename from tests/visual_bert/test_modeling_visual_bert.py
rename to tests/models/visual_bert/test_modeling_visual_bert.py
index e84b4d11a14d35..99db914072ccab 100644
--- a/tests/visual_bert/test_modeling_visual_bert.py
+++ b/tests/models/visual_bert/test_modeling_visual_bert.py
@@ -20,8 +20,8 @@
from transformers import VisualBertConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
diff --git a/tests/models/vit/__init__.py b/tests/models/vit/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/vit/test_feature_extraction_vit.py b/tests/models/vit/test_feature_extraction_vit.py
similarity index 98%
rename from tests/vit/test_feature_extraction_vit.py
rename to tests/models/vit/test_feature_extraction_vit.py
index df722d74bb08bc..2daf6452fff5c0 100644
--- a/tests/vit/test_feature_extraction_vit.py
+++ b/tests/models/vit/test_feature_extraction_vit.py
@@ -21,7 +21,7 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
-from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
diff --git a/tests/vit/test_modeling_flax_vit.py b/tests/models/vit/test_modeling_flax_vit.py
similarity index 63%
rename from tests/vit/test_modeling_flax_vit.py
rename to tests/models/vit/test_modeling_flax_vit.py
index 0af2123c905d4b..56fe28d41bafd8 100644
--- a/tests/vit/test_modeling_flax_vit.py
+++ b/tests/models/vit/test_modeling_flax_vit.py
@@ -20,8 +20,8 @@
from transformers import ViTConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor
if is_flax_available():
@@ -67,9 +67,9 @@ def __init__(
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
- # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
- self.expected_seq_length = num_patches + 1
+ self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -123,50 +123,6 @@ def setUp(self) -> None:
def test_config(self):
self.config_tester.run_common_tests()
- # We need to override this test because in ViT, the seq_len equals the number of patches + 1
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- seq_length = self.model_tester.expected_seq_length
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_length, seq_length],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
-
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_length, seq_length],
- )
-
# We neeed to override this test because ViT's forward signature is different than text models.
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -180,7 +136,7 @@ def test_forward_signature(self):
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
- # We neeed to override this test because ViT expects pixel_values instead of input_ids
+ # We need to override this test because ViT expects pixel_values instead of input_ids
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -204,35 +160,6 @@ def model_jitted(pixel_values, **kwargs):
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
- # We need to override this test because in ViT, the seq_len equals the number of patches + 1
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
-
- seq_length = self.model_tester.expected_seq_length
-
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- hidden_states = outputs.hidden_states
-
- self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
diff --git a/tests/vit/test_modeling_tf_vit.py b/tests/models/vit/test_modeling_tf_vit.py
similarity index 57%
rename from tests/vit/test_modeling_tf_vit.py
rename to tests/models/vit/test_modeling_tf_vit.py
index 9ad64e82370144..096558091ac820 100644
--- a/tests/vit/test_modeling_tf_vit.py
+++ b/tests/models/vit/test_modeling_tf_vit.py
@@ -16,16 +16,14 @@
import inspect
-import os
-import tempfile
import unittest
from transformers import ViTConfig
-from transformers.testing_utils import require_tf, require_vision, slow, tooslow
+from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
@@ -80,9 +78,9 @@ def __init__(
self.initializer_range = initializer_range
self.scope = scope
- # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
- self.expected_seq_length = num_patches + 1
+ self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -114,18 +112,14 @@ def get_config(self):
def create_and_check_model(self, config, pixel_values, labels):
model = TFViTModel(config=config)
result = model(pixel_values, training=False)
- self.parent.assertEqual(
- result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
- )
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Test with an image with different size than the one specified in config.
image_size = self.image_size // 2
pixel_values = pixel_values[:, :, :image_size, :image_size]
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
- expected_seq_length = (image_size // self.patch_size) ** 2 + 1
- self.parent.assertEqual(
- result.last_hidden_state.shape, (self.batch_size, expected_seq_length, self.hidden_size)
- )
+ seq_length = (image_size // self.patch_size) ** 2 + 1
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, seq_length, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
@@ -166,12 +160,12 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()
+ @unittest.skip(reason="ViT does not use inputs_embeds")
def test_inputs_embeds(self):
- # ViT does not use inputs_embeds
pass
+ @unittest.skip(reason="ViT does not use inputs_embeds")
def test_graph_mode_with_inputs_embeds(self):
- # ViT does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -199,131 +193,6 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
- # overwrite from common since `encoder_seq_length` and `encoder_key_length` are calculated
- # in a different way than in text models.
- @tooslow
- def test_saved_model_creation_extended(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.output_hidden_states = True
- config.output_attentions = True
-
- if hasattr(config, "use_cache"):
- config.use_cache = True
-
- # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
- seq_len = self.model_tester.expected_seq_length
-
- for model_class in self.all_model_classes:
- class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
- model = model_class(config)
- num_out = len(model(class_inputs_dict))
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname, saved_model=True)
- saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
- model = tf.keras.models.load_model(saved_model_dir)
- outputs = model(class_inputs_dict)
-
- output_hidden_states = outputs["hidden_states"]
- output_attentions = outputs["attentions"]
-
- self.assertEqual(len(outputs), num_out)
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
-
- self.assertEqual(len(output_hidden_states), expected_num_layers)
- self.assertListEqual(
- list(output_hidden_states[0].shape[-2:]),
- [seq_len, self.model_tester.hidden_size],
- )
-
- self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(output_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
-
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
- seq_len = self.model_tester.expected_seq_length
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
-
- self.assertEqual(out_len + 1, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
-
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- # ViT has a different seq_length
- seq_length = self.model_tester.expected_seq_length
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
diff --git a/tests/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py
similarity index 67%
rename from tests/vit/test_modeling_vit.py
rename to tests/models/vit/test_modeling_vit.py
index 117815fa6db496..bfca8bf5cb9aa5 100644
--- a/tests/vit/test_modeling_vit.py
+++ b/tests/models/vit/test_modeling_vit.py
@@ -22,8 +22,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
@@ -81,9 +81,9 @@ def __init__(
self.scope = scope
self.encoder_stride = encoder_stride
- # in ViT, the expected seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
- self.expected_seq_length = num_patches + 1
+ self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -118,9 +118,7 @@ def create_and_check_model(self, config, pixel_values, labels):
model.to(torch_device)
model.eval()
result = model(pixel_values)
- self.parent.assertEqual(
- result.last_hidden_state.shape, (self.batch_size, self.expected_seq_length, self.hidden_size)
- )
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
@@ -157,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
+ fx_compatible = True
test_pruning = False
test_resize_embeddings = False
@@ -169,8 +168,8 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()
+ @unittest.skip(reason="ViT does not use inputs_embeds")
def test_inputs_embeds(self):
- # ViT does not use inputs_embeds
pass
def test_model_common_attributes(self):
@@ -198,93 +197,6 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
- def test_attention_outputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- seq_len = self.model_tester.expected_seq_length
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- self.assertEqual(out_len + 1, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, seq_len, seq_len],
- )
-
- def test_hidden_states_output(self):
- def check_hidden_states_output(inputs_dict, config, model_class):
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- hidden_states = outputs.hidden_states
-
- expected_num_layers = getattr(
- self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
- )
- self.assertEqual(len(hidden_states), expected_num_layers)
-
- self.assertListEqual(
- list(hidden_states[0].shape[-2:]),
- [self.model_tester.expected_seq_length, self.model_tester.hidden_size],
- )
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- inputs_dict["output_hidden_states"] = True
- check_hidden_states_output(inputs_dict, config, model_class)
-
- # check that output_hidden_states also work using config
- del inputs_dict["output_hidden_states"]
- config.output_hidden_states = True
-
- check_hidden_states_output(inputs_dict, config, model_class)
-
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
diff --git a/tests/models/vit_mae/__init__.py b/tests/models/vit_mae/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py
similarity index 98%
rename from tests/vit_mae/test_modeling_tf_vit_mae.py
rename to tests/models/vit_mae/test_modeling_tf_vit_mae.py
index 5a95f46350658a..cb54e29b80f701 100644
--- a/tests/vit_mae/test_modeling_tf_vit_mae.py
+++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py
@@ -30,8 +30,8 @@
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, require_vision, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
@@ -107,6 +107,10 @@ def get_config(self):
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
+ decoder_hidden_size=self.hidden_size,
+ decoder_num_hidden_layers=self.num_hidden_layers,
+ decoder_num_attention_heads=self.num_attention_heads,
+ decoder_intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
diff --git a/tests/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py
similarity index 99%
rename from tests/vit_mae/test_modeling_vit_mae.py
rename to tests/models/vit_mae/test_modeling_vit_mae.py
index fae72a8ad7be34..191984d82f55ba 100644
--- a/tests/vit_mae/test_modeling_vit_mae.py
+++ b/tests/models/vit_mae/test_modeling_vit_mae.py
@@ -26,8 +26,8 @@
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
diff --git a/tests/models/wav2vec2/__init__.py b/tests/models/wav2vec2/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/wav2vec2/test_feature_extraction_wav2vec2.py b/tests/models/wav2vec2/test_feature_extraction_wav2vec2.py
similarity index 99%
rename from tests/wav2vec2/test_feature_extraction_wav2vec2.py
rename to tests/models/wav2vec2/test_feature_extraction_wav2vec2.py
index 67c4e050fdf18c..98cf2f1c495bc6 100644
--- a/tests/wav2vec2/test_feature_extraction_wav2vec2.py
+++ b/tests/models/wav2vec2/test_feature_extraction_wav2vec2.py
@@ -23,7 +23,7 @@
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from transformers.testing_utils import require_torch, slow
-from ..test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
+from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
global_rng = random.Random()
diff --git a/tests/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
similarity index 98%
rename from tests/wav2vec2/test_modeling_flax_wav2vec2.py
rename to tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
index f70bb319fc1590..b74e271c02d6fc 100644
--- a/tests/wav2vec2/test_modeling_flax_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
@@ -30,7 +30,7 @@
slow,
)
-from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask
if is_flax_available():
@@ -117,7 +117,7 @@ def __init__(
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = Wav2Vec2Config(
@@ -463,7 +463,8 @@ def test_inference_ctc_robust_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
similarity index 98%
rename from tests/wav2vec2/test_modeling_tf_wav2vec2.py
rename to tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
index 6da315e7898b53..323f44ba99fb4f 100644
--- a/tests/wav2vec2/test_modeling_tf_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
@@ -29,8 +29,8 @@
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
from transformers.utils import is_librosa_available, is_pyctcdecode_available
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
@@ -548,7 +548,8 @@ def test_inference_ctc_robust_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py
similarity index 99%
rename from tests/wav2vec2/test_modeling_wav2vec2.py
rename to tests/models/wav2vec2/test_modeling_wav2vec2.py
index c1978a45b7a0f7..21f77b19a553ca 100644
--- a/tests/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -33,8 +33,8 @@
torch_device,
)
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -150,7 +150,7 @@ def __init__(
self.adapter_output_seq_length = (self.output_seq_length - 1) // adapter_stride + 1
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
@@ -1179,7 +1179,8 @@ def test_inference_ctc_robust_batched(self):
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
- "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
+ "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around"
+ " him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
@@ -1461,8 +1462,11 @@ def test_phoneme_recognition(self):
EXPECTED_TRANSCRIPTIONS = [
"É m Ʀ n s É d t É Ć° É j uĖ n ÉŖ v É s s É aÉŖ É É” z ÉŖ s t",
- "s w É t k Ź v É d b ɹ iĖ É n z b ÉĖ d i t ɹ ÉŖ k l ÉŖ Å ÉŖ n t É Ć° É t aÉŖ t l oÉŖ n k l ÉĖ Īø Ć° Ʀ w Ź z Ć° ÉŖ oŹ n l i É” ÉĖɹ m É n t h iĖ w ÉĖɹ",
- "Ć° É k aÉŖ t É n h ÉŖ z tŹ É s t s t ÉŖ l d ɹ ÉŖ p ÉŖ Å b l Ź d Ć° ÉŖ eÉŖ k Ź v h ÉŖ z oŹ v É s t ɹ eÉŖ n d aÉŖ z iĖ v É n Ć° É s ÉĖɹ ɹ ÉŖ Å É É¹ iĖ n É É É¹ aŹ n d h ÉŖ m w ÉŖ Ć° É Īø aŹ z É n d z Ź v s p É k t eÉŖ ɾ É z w ÉĖ t ɹ ÉŖ v ÉŖ Ʀ l įµ» ɾ i z n ÉĖ t w ÉĖ Īø Īø ÉŖ Å k ÉŖ Å É b aŹ t",
+ "s w É t k Ź v É d b ɹ iĖ É n z b ÉĖ d i t ɹ ÉŖ k l ÉŖ Å ÉŖ n t É Ć° É t aÉŖ t l oÉŖ n k l ÉĖ Īø Ć° Ʀ w Ź z Ć° ÉŖ oŹ"
+ " n l i É” ÉĖɹ m É n t h iĖ w ÉĖɹ",
+ "Ć° É k aÉŖ t É n h ÉŖ z tŹ É s t s t ÉŖ l d ɹ ÉŖ p ÉŖ Å b l Ź d Ć° ÉŖ eÉŖ k Ź v h ÉŖ z oŹ v É s t ɹ eÉŖ n d aÉŖ z iĖ"
+ " v É n Ć° É s ÉĖɹ ɹ ÉŖ Å É É¹ iĖ n É É É¹ aŹ n d h ÉŖ m w ÉŖ Ć° É Īø aŹ z É n d z Ź v s p É k t eÉŖ ɾ É z w ÉĖ t ɹ"
+ " ÉŖ v ÉŖ Ʀ l įµ» ɾ i z n ÉĖ t w ÉĖ Īø Īø ÉŖ Å k ÉŖ Å É b aŹ t",
"h ÉŖ z ÉŖ n s t É n t v p Ʀ n ÉŖ k w Ź z f ÉĖ l oŹ d b aÉŖ É s m ÉĖ l Ź ÉĖɹ p b l oŹ h aÉŖ É n h ÉŖ z tŹ É s t",
]
# should correspond to =>:
diff --git a/tests/wav2vec2/test_processor_wav2vec2.py b/tests/models/wav2vec2/test_processor_wav2vec2.py
similarity index 100%
rename from tests/wav2vec2/test_processor_wav2vec2.py
rename to tests/models/wav2vec2/test_processor_wav2vec2.py
diff --git a/tests/wav2vec2/test_tokenization_wav2vec2.py b/tests/models/wav2vec2/test_tokenization_wav2vec2.py
similarity index 99%
rename from tests/wav2vec2/test_tokenization_wav2vec2.py
rename to tests/models/wav2vec2/test_tokenization_wav2vec2.py
index 48072fb81c22b2..4027e0cefc4d24 100644
--- a/tests/wav2vec2/test_tokenization_wav2vec2.py
+++ b/tests/models/wav2vec2/test_tokenization_wav2vec2.py
@@ -32,7 +32,7 @@
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2CTCTokenizerOutput
from transformers.testing_utils import require_torch, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
global_rng = random.Random()
diff --git a/tests/models/wav2vec2_conformer/__init__.py b/tests/models/wav2vec2_conformer/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py
new file mode 100644
index 00000000000000..cb2719a591b61c
--- /dev/null
+++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py
@@ -0,0 +1,939 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """
+
+import math
+import unittest
+
+import numpy as np
+from datasets import load_dataset
+
+from transformers import Wav2Vec2ConformerConfig, is_torch_available
+from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForXVector,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2FeatureExtractor,
+ Wav2Vec2Processor,
+ )
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+ Wav2Vec2ConformerGumbelVectorQuantizer,
+ _compute_mask_indices,
+ _sample_negative_indices,
+ )
+
+
+class Wav2Vec2ConformerModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=1024, # speech is longer
+ is_training=False,
+ hidden_size=16,
+ feat_extract_norm="group",
+ feat_extract_dropout=0.0,
+ feat_extract_activation="gelu",
+ conv_dim=(32, 32, 32),
+ conv_stride=(4, 4, 4),
+ conv_kernel=(8, 8, 8),
+ conv_bias=False,
+ num_conv_pos_embeddings=16,
+ num_conv_pos_embedding_groups=2,
+ num_hidden_layers=4,
+ num_attention_heads=2,
+ hidden_dropout_prob=0.1,
+ intermediate_size=20,
+ layer_norm_eps=1e-5,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ mask_time_prob=0.5,
+ mask_time_length=2,
+ vocab_size=32,
+ do_stable_layer_norm=False,
+ num_adapter_layers=1,
+ adapter_stride=2,
+ tdnn_dim=(32, 32),
+ tdnn_kernel=(5, 3),
+ tdnn_dilation=(1, 2),
+ xvector_output_dim=32,
+ position_embeddings_type="relative",
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.hidden_size = hidden_size
+ self.feat_extract_norm = feat_extract_norm
+ self.feat_extract_dropout = feat_extract_dropout
+ self.feat_extract_activation = feat_extract_activation
+ self.conv_dim = conv_dim
+ self.conv_stride = conv_stride
+ self.conv_kernel = conv_kernel
+ self.conv_bias = conv_bias
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.intermediate_size = intermediate_size
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.vocab_size = vocab_size
+ self.do_stable_layer_norm = do_stable_layer_norm
+ self.num_adapter_layers = num_adapter_layers
+ self.adapter_stride = adapter_stride
+ self.mask_time_prob = mask_time_prob
+ self.mask_time_length = mask_time_length
+ self.scope = scope
+ self.tdnn_dim = tdnn_dim
+ self.tdnn_kernel = tdnn_kernel
+ self.tdnn_dilation = tdnn_dilation
+ self.xvector_output_dim = xvector_output_dim
+ self.position_embeddings_type = position_embeddings_type
+
+ output_seq_length = self.seq_length
+ for kernel, stride in zip(self.conv_kernel, self.conv_stride):
+ output_seq_length = (output_seq_length - (kernel - 1)) / stride
+ self.output_seq_length = int(math.ceil(output_seq_length))
+ self.encoder_seq_length = self.output_seq_length
+
+ self.adapter_output_seq_length = (self.output_seq_length - 1) // adapter_stride + 1
+
+ def prepare_config_and_inputs(self, position_embeddings_type="relative"):
+ input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ config = self.get_config(position_embeddings_type=position_embeddings_type)
+
+ return config, input_values, attention_mask
+
+ def get_config(self, position_embeddings_type="relative"):
+ return Wav2Vec2ConformerConfig(
+ hidden_size=self.hidden_size,
+ feat_extract_norm=self.feat_extract_norm,
+ feat_extract_dropout=self.feat_extract_dropout,
+ feat_extract_activation=self.feat_extract_activation,
+ conv_dim=self.conv_dim,
+ conv_stride=self.conv_stride,
+ conv_kernel=self.conv_kernel,
+ conv_bias=self.conv_bias,
+ mask_time_prob=self.mask_time_prob,
+ mask_time_length=self.mask_time_length,
+ num_conv_pos_embeddings=self.num_conv_pos_embeddings,
+ num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ intermediate_size=self.intermediate_size,
+ layer_norm_eps=self.layer_norm_eps,
+ do_stable_layer_norm=self.do_stable_layer_norm,
+ hidden_act=self.hidden_act,
+ initializer_range=self.initializer_range,
+ vocab_size=self.vocab_size,
+ num_adapter_layers=self.num_adapter_layers,
+ adapter_stride=self.adapter_stride,
+ tdnn_dim=self.tdnn_dim,
+ tdnn_kernel=self.tdnn_kernel,
+ tdnn_dilation=self.tdnn_dilation,
+ xvector_output_dim=self.xvector_output_dim,
+ position_embeddings_type=position_embeddings_type,
+ )
+
+ def create_and_check_model(self, config, input_values, attention_mask):
+ model = Wav2Vec2ConformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_with_adapter(self, config, input_values, attention_mask):
+ config.add_adapter = True
+ model = Wav2Vec2ConformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size)
+ )
+
+ def create_and_check_model_with_adapter_for_ctc(self, config, input_values, attention_mask):
+ config.add_adapter = True
+ config.output_hidden_size = 2 * config.hidden_size
+ model = Wav2Vec2ConformerForCTC(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.logits.shape, (self.batch_size, self.adapter_output_seq_length, self.vocab_size)
+ )
+
+ def create_and_check_model_with_adapter_proj_dim(self, config, input_values, attention_mask):
+ config.add_adapter = True
+ config.output_hidden_size = 8
+ model = Wav2Vec2ConformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_values, attention_mask=attention_mask)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
+ )
+
+ def create_and_check_batch_inference(self, config, input_values, *args):
+ # test does not pass for models making use of `group_norm`
+ # check: https://github.com/pytorch/fairseq/issues/3227
+ model = Wav2Vec2ConformerModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ input_values = input_values[:3]
+ attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0.0
+
+ batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
+
+ for i in range(input_values.shape[0]):
+ input_slice = input_values[i : i + 1, : input_lengths[i]]
+ output = model(input_slice).last_hidden_state
+
+ batch_output = batch_outputs[i : i + 1, : output.shape[1]]
+ self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
+
+ def check_ctc_loss(self, config, input_values, *args):
+ model = Wav2Vec2ConformerForCTC(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_values = input_values[:3]
+ attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ model.config.ctc_loss_reduction = "sum"
+ sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+
+ model.config.ctc_loss_reduction = "mean"
+ mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(sum_loss, float))
+ self.parent.assertTrue(isinstance(mean_loss, float))
+
+ def check_seq_classifier_loss(self, config, input_values, *args):
+ model = Wav2Vec2ConformerForSequenceClassification(config=config)
+ model.to(torch_device)
+
+ # make sure that dropout is disabled
+ model.eval()
+
+ input_values = input_values[:3]
+ attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+ attention_mask[i, input_lengths[i] :] = 0
+
+ masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
+ unmasked_loss = model(input_values, labels=labels).loss.item()
+
+ self.parent.assertTrue(isinstance(masked_loss, float))
+ self.parent.assertTrue(isinstance(unmasked_loss, float))
+ self.parent.assertTrue(masked_loss != unmasked_loss)
+
+ def check_ctc_training(self, config, input_values, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2ConformerForCTC(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze feature encoder
+ model.freeze_feature_encoder()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+
+ if max_length_labels[i] < labels.shape[-1]:
+ # it's important that we make sure that target lenghts are at least
+ # one shorter than logit lenghts to prevent -inf
+ labels[i, max_length_labels[i] - 1 :] = -100
+
+ loss = model(input_values, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_seq_classifier_training(self, config, input_values, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2ConformerForSequenceClassification(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze everything but the classification head
+ model.freeze_base_model()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+
+ loss = model(input_values, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_xvector_training(self, config, input_values, *args):
+ config.ctc_zero_infinity = True
+ model = Wav2Vec2ConformerForXVector(config=config)
+ model.to(torch_device)
+ model.train()
+
+ # freeze everything but the classification head
+ model.freeze_base_model()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
+
+ # pad input
+ for i in range(len(input_lengths)):
+ input_values[i, input_lengths[i] :] = 0.0
+
+ loss = model(input_values, labels=labels).loss
+ self.parent.assertFalse(torch.isinf(loss).item())
+
+ loss.backward()
+
+ def check_labels_out_of_vocab(self, config, input_values, *args):
+ model = Wav2Vec2ConformerForCTC(config)
+ model.to(torch_device)
+ model.train()
+
+ input_values = input_values[:3]
+
+ input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+ max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
+ labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100)
+
+ with self.parent.assertRaises(ValueError):
+ model(input_values, labels=labels)
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_values, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class Wav2Vec2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForXVector,
+ )
+ if is_torch_available()
+ else ()
+ )
+ test_pruning = False
+ test_headmasking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = Wav2Vec2ConformerModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Wav2Vec2ConformerConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_relative(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_rotary(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_no_rel_pos(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type=None)
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_model_with_adapter(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
+
+ def test_model_with_adapter_for_ctc(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter_for_ctc(*config_and_inputs)
+
+ def test_model_with_adapter_proj_dim(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
+
+ def test_ctc_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_loss(*config_and_inputs)
+
+ def test_seq_classifier_loss_inference(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_seq_classifier_loss(*config_and_inputs)
+
+ def test_ctc_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_ctc_training(*config_and_inputs)
+
+ def test_seq_classifier_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_seq_classifier_training(*config_and_inputs)
+
+ def test_xvector_train(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_xvector_training(*config_and_inputs)
+
+ def test_labels_out_of_vocab(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
+
+ # Wav2Vec2Conformer has no inputs_embeds
+ def test_inputs_embeds(self):
+ pass
+
+ # `input_ids` is renamed to `input_values`
+ def test_forward_signature(self):
+ pass
+
+ # Wav2Vec2Conformer cannot resize token embeddings
+ # since it has no tokens embeddings
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ # Wav2Vec2Conformer has no inputs_embeds
+ # and thus the `get_input_embeddings` fn
+ # is not implemented
+ def test_model_common_attributes(self):
+ pass
+
+ @is_pt_flax_cross_test
+ # non-robust architecture does not exist in Flax
+ def test_equivalence_flax_to_pt(self):
+ pass
+
+ @is_pt_flax_cross_test
+ # non-robust architecture does not exist in Flax
+ def test_equivalence_pt_to_flax(self):
+ pass
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ # set layer drop to 0
+ model.config.layerdrop = 0.0
+
+ input_values = inputs_dict["input_values"]
+
+ input_lengths = torch.tensor(
+ [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
+ )
+ output_lengths = model._get_feat_extract_output_lengths(input_lengths)
+
+ labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
+ inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
+ inputs_dict["labels"] = labels
+
+ outputs = model(**inputs_dict)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = [
+ "conv.weight",
+ "masked_spec_embed",
+ "codevectors",
+ "quantizer.weight_proj.weight",
+ "project_hid.weight",
+ "project_hid.bias",
+ "project_q.weight",
+ "project_q.bias",
+ "pos_bias_v",
+ "pos_bias_u",
+ "pointwise_conv1",
+ "pointwise_conv2",
+ "feature_projection.projection.weight",
+ "feature_projection.projection.bias",
+ "objective.weight",
+ ]
+ if param.requires_grad:
+ if any([x in name for x in uniform_init_parms]):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # overwrite from test_modeling_common
+ def _mock_init_weights(self, module):
+ if hasattr(module, "weight") and module.weight is not None:
+ module.weight.data.fill_(3)
+ if hasattr(module, "weight_g") and module.weight_g is not None:
+ module.weight_g.data.fill_(3)
+ if hasattr(module, "weight_v") and module.weight_v is not None:
+ module.weight_v.data.fill_(3)
+ if hasattr(module, "bias") and module.bias is not None:
+ module.bias.data.fill_(3)
+ if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
+ module.pos_bias_u.data.fill_(3)
+ if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
+ module.pos_bias_v.data.fill_(3)
+ if hasattr(module, "codevectors") and module.codevectors is not None:
+ module.codevectors.data.fill_(3)
+ if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
+ module.masked_spec_embed.data.fill_(3)
+
+ def test_mask_feature_prob_ctc(self):
+ model = Wav2Vec2ConformerForCTC.from_pretrained(
+ "hf-internal-testing/tiny-random-wav2vec2-conformer", mask_feature_prob=0.2, mask_feature_length=2
+ )
+ model.to(torch_device).train()
+ processor = Wav2Vec2Processor.from_pretrained(
+ "hf-internal-testing/tiny-random-wav2vec2-conformer", return_attention_mask=True
+ )
+
+ batch_duration_in_seconds = [1, 3, 2, 6]
+ input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+ batch = processor(
+ input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+ )
+
+ logits = model(
+ input_values=batch["input_values"].to(torch_device),
+ attention_mask=batch["attention_mask"].to(torch_device),
+ ).logits
+
+ self.assertEqual(logits.shape, (4, 1498, 32))
+
+ def test_mask_time_prob_ctc(self):
+ model = Wav2Vec2ConformerForCTC.from_pretrained(
+ "hf-internal-testing/tiny-random-wav2vec2-conformer", mask_time_prob=0.2, mask_time_length=2
+ )
+ model.to(torch_device).train()
+ processor = Wav2Vec2Processor.from_pretrained(
+ "hf-internal-testing/tiny-random-wav2vec2-conformer", return_attention_mask=True
+ )
+
+ batch_duration_in_seconds = [1, 3, 2, 6]
+ input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
+
+ batch = processor(
+ input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
+ )
+
+ logits = model(
+ input_values=batch["input_values"].to(torch_device),
+ attention_mask=batch["attention_mask"].to(torch_device),
+ ).logits
+
+ self.assertEqual(logits.shape, (4, 1498, 32))
+
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = Wav2Vec2ConformerModel.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class Wav2Vec2ConformerUtilsTest(unittest.TestCase):
+ def test_compute_mask_indices(self):
+ batch_size = 4
+ sequence_length = 60
+ mask_prob = 0.5
+ mask_length = 1
+
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
+
+ def test_compute_mask_indices_low_prob(self):
+ # with these settings num_masked_spans=0.5, which means probabilistic rounding
+ # ensures that in 5 out of 10 method calls, num_masked_spans=0, and in
+ # the other 5 out of 10, cases num_masked_spans=1
+ n_trials = 100
+ batch_size = 4
+ sequence_length = 100
+ mask_prob = 0.05
+ mask_length = 10
+
+ count_dimensions_masked = 0
+ count_dimensions_not_masked = 0
+
+ for _ in range(n_trials):
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ num_masks = torch.sum(mask).item()
+
+ if num_masks > 0:
+ count_dimensions_masked += 1
+ else:
+ count_dimensions_not_masked += 1
+
+ # as we test for at least 10 masked dimension and at least
+ # 10 non-masked dimension, this test could fail with probability:
+ # P(100 coin flips, at most 9 heads) = 1.66e-18
+ self.assertGreater(count_dimensions_masked, int(n_trials * 0.1))
+ self.assertGreater(count_dimensions_not_masked, int(n_trials * 0.1))
+
+ def test_compute_mask_indices_overlap(self):
+ batch_size = 4
+ sequence_length = 80
+ mask_prob = 0.5
+ mask_length = 4
+
+ mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
+ for batch_sum in mask.sum(axis=-1):
+ self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+ def test_compute_mask_indices_attn_mask_overlap(self):
+ batch_size = 4
+ sequence_length = 80
+ mask_prob = 0.5
+ mask_length = 4
+
+ attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ attention_mask[:2, sequence_length // 2 :] = 0
+
+ mask = _compute_mask_indices(
+ (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
+ )
+ mask = torch.from_numpy(mask).to(torch_device)
+
+ for batch_sum in mask.sum(axis=-1):
+ self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+ self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
+
+ def test_compute_mask_indices_short_audio(self):
+ batch_size = 4
+ sequence_length = 100
+ mask_prob = 0.05
+ mask_length = 10
+
+ attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ # force one example to be heavily padded
+ attention_mask[0, 5:] = 0
+
+ mask = _compute_mask_indices(
+ (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask, min_masks=2
+ )
+
+ # make sure that non-padded examples cannot be padded
+ self.assertFalse(mask[0][attention_mask[0].to(torch.bool).cpu()].any())
+
+ def test_compute_perplexity(self):
+ probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
+
+ ppl = Wav2Vec2ConformerGumbelVectorQuantizer._compute_perplexity(probs)
+ self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
+
+ # mask half of the input
+ mask = torch.ones((2,), device=torch_device, dtype=torch.bool)
+ mask[0] = 0
+
+ ppl = Wav2Vec2ConformerGumbelVectorQuantizer._compute_perplexity(probs, mask)
+ self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
+
+ def test_sample_negatives(self):
+ batch_size = 2
+ sequence_length = 10
+ hidden_size = 4
+ num_negatives = 3
+
+ features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
+ sequence_length, hidden_size
+ ) # each value in vector consits of same value
+ features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
+
+ # sample negative indices
+ sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
+ sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+ negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+ negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
+ self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+ # make sure no negatively sampled vector is actually a positive one
+ for negative in negatives:
+ self.assertTrue(((negative - features) == 0).sum() == 0.0)
+
+ # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
+ self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+
+ def test_sample_negatives_with_mask(self):
+ batch_size = 2
+ sequence_length = 10
+ hidden_size = 4
+ num_negatives = 3
+
+ # second half of last input tensor is padded
+ mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+ mask[-1, sequence_length // 2 :] = 0
+
+ features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
+ sequence_length, hidden_size
+ ) # each value in vector consits of same value
+ features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
+
+ # replace masked feature vectors with -100 to test that those are not sampled
+ features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
+
+ # sample negative indices
+ sampled_negative_indices = _sample_negative_indices(
+ (batch_size, sequence_length), num_negatives, mask.cpu().numpy()
+ )
+ sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+ negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+ negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
+
+ self.assertTrue((negatives >= 0).all().item())
+
+ self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+ # make sure no negatively sampled vector is actually a positive one
+ for negative in negatives:
+ self.assertTrue(((negative - features) == 0).sum() == 0.0)
+
+ # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
+ self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+
+
+@require_torch
+@slow
+class Wav2Vec2ConformerModelIntegrationTest(unittest.TestCase):
+ def _load_datasamples(self, num_samples):
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ # automatic decoding with librispeech
+ speech_samples = ds.sort("id").filter(lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)])
+ speech_samples = speech_samples[:num_samples]["audio"]
+
+ return [x["array"] for x in speech_samples]
+
+ def test_inference_ctc_normal_batched_rel_pos(self):
+ model = Wav2Vec2ConformerForCTC.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large-960h-ft")
+ model.to(torch_device)
+ processor = Wav2Vec2Processor.from_pretrained(
+ "facebook/wav2vec2-conformer-rel-pos-large-960h-ft", do_lower_case=True
+ )
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True)
+
+ input_values = inputs.input_values.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_values).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe sir i exist",
+ "sweat covered brion's body trickling into the tight loincloth that was the only garment he wore",
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ def test_inference_ctc_normal_batched_rope(self):
+ model = Wav2Vec2ConformerForCTC.from_pretrained("facebook/wav2vec2-conformer-rope-large-960h-ft")
+ model.to(torch_device)
+ processor = Wav2Vec2Processor.from_pretrained(
+ "facebook/wav2vec2-conformer-rope-large-960h-ft", do_lower_case=True
+ )
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True)
+
+ input_values = inputs.input_values.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(input_values).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe sir i exist",
+ "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ def test_inference_pretrained(self):
+ model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ model.to(torch_device)
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+ "facebook/wav2vec2-conformer-rel-pos-large", return_attention_mask=True
+ )
+ input_speech = self._load_datasamples(2)
+
+ inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
+
+ batch_size = inputs_dict["input_values"].shape[0]
+ feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))
+
+ features_shape = (batch_size, feature_seq_length)
+
+ torch.manual_seed(0)
+ mask_time_indices = _compute_mask_indices(
+ features_shape,
+ model.config.mask_time_prob,
+ model.config.mask_time_length,
+ min_masks=2,
+ )
+ mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
+
+ with torch.no_grad():
+ outputs = model(
+ inputs_dict.input_values.to(torch_device),
+ attention_mask=inputs_dict.attention_mask.to(torch_device),
+ mask_time_indices=mask_time_indices,
+ )
+
+ # compute cosine similarity
+ cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+ # retrieve cosine sim of masked features
+ cosine_sim_masked = cosine_sim[mask_time_indices]
+
+ # ... now compare to randomly initialized model
+
+ config = Wav2Vec2ConformerConfig.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ model_rand = Wav2Vec2ConformerForPreTraining(config).to(torch_device).eval()
+
+ with torch.no_grad():
+ outputs_rand = model_rand(
+ inputs_dict.input_values.to(torch_device),
+ attention_mask=inputs_dict.attention_mask.to(torch_device),
+ mask_time_indices=mask_time_indices,
+ )
+
+ # compute cosine similarity
+ cosine_sim_rand = torch.cosine_similarity(
+ outputs_rand.projected_states, outputs_rand.projected_quantized_states, dim=-1
+ )
+
+ # retrieve cosine sim of masked features
+ cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
+
+ # a pretrained wav2vec2_conformer model has learned to predict the quantized latent states
+ # => the cosine similarity between quantized states and predicted states > 0.5
+ # a random wav2vec2_conformer model has not learned to predict the quantized latent states
+ # => the cosine similarity between quantized states and predicted states is very likely < 0.1
+ self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
diff --git a/tests/models/wav2vec2_phoneme/__init__.py b/tests/models/wav2vec2_phoneme/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/wav2vec2_phoneme/test_tokenization_wav2vec2_phoneme.py b/tests/models/wav2vec2_phoneme/test_tokenization_wav2vec2_phoneme.py
similarity index 99%
rename from tests/wav2vec2_phoneme/test_tokenization_wav2vec2_phoneme.py
rename to tests/models/wav2vec2_phoneme/test_tokenization_wav2vec2_phoneme.py
index 577471c0fa9e27..0411a863bc723a 100644
--- a/tests/wav2vec2_phoneme/test_tokenization_wav2vec2_phoneme.py
+++ b/tests/models/wav2vec2_phoneme/test_tokenization_wav2vec2_phoneme.py
@@ -23,7 +23,7 @@
from transformers.models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizerOutput
from transformers.testing_utils import require_phonemizer
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
@require_phonemizer
diff --git a/tests/models/wav2vec2_with_lm/__init__.py b/tests/models/wav2vec2_with_lm/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
similarity index 100%
rename from tests/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
rename to tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
diff --git a/tests/models/wavlm/__init__.py b/tests/models/wavlm/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/wavlm/test_modeling_wavlm.py b/tests/models/wavlm/test_modeling_wavlm.py
similarity index 99%
rename from tests/wavlm/test_modeling_wavlm.py
rename to tests/models/wavlm/test_modeling_wavlm.py
index 937325e721ce85..297d207af3e56a 100644
--- a/tests/wavlm/test_modeling_wavlm.py
+++ b/tests/models/wavlm/test_modeling_wavlm.py
@@ -23,8 +23,8 @@
from transformers import WavLMConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torchaudio, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import (
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@@ -114,7 +114,7 @@ def __init__(
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ input_values = floats_tensor([self.batch_size, self.seq_length], scale=1.0)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
diff --git a/tests/models/xglm/__init__.py b/tests/models/xglm/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/xglm/test_modeling_flax_xglm.py b/tests/models/xglm/test_modeling_flax_xglm.py
similarity index 98%
rename from tests/xglm/test_modeling_flax_xglm.py
rename to tests/models/xglm/test_modeling_flax_xglm.py
index 45399d96242171..f20a1b378f5ff0 100644
--- a/tests/xglm/test_modeling_flax_xglm.py
+++ b/tests/models/xglm/test_modeling_flax_xglm.py
@@ -21,8 +21,8 @@
from transformers import XGLMConfig, XGLMTokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_sentencepiece, slow
-from ..generation.test_generation_flax_utils import FlaxGenerationTesterMixin
-from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
diff --git a/tests/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py
similarity index 79%
rename from tests/xglm/test_modeling_xglm.py
rename to tests/models/xglm/test_modeling_xglm.py
index 1f80165a84cf82..f4da4994266d27 100644
--- a/tests/xglm/test_modeling_xglm.py
+++ b/tests/models/xglm/test_modeling_xglm.py
@@ -13,17 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import datetime
import math
+import os
+import pickle
+import tempfile
import unittest
from transformers import XGLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.utils import is_torch_fx_available
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
if is_torch_available():
@@ -31,6 +40,9 @@
from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
+
class XGLMModelTester:
def __init__(
@@ -299,6 +311,7 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
+ fx_compatible = True
test_missing_keys = False
test_pruning = False
@@ -337,6 +350,112 @@ def test_xglm_weight_initialization(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = [
+ "input_ids",
+ "attention_mask",
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ "input_features",
+ ]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = [
+ "input_ids",
+ "attention_mask",
+ "token_type_ids",
+ "pixel_values",
+ "bbox",
+ "input_features",
+ ]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except RuntimeError as e:
+ self.fail(f"Couldn't trace module: {e}")
+
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
@slow
def test_batch_generation(self):
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
diff --git a/tests/xglm/test_tokenization_xglm.py b/tests/models/xglm/test_tokenization_xglm.py
similarity index 94%
rename from tests/xglm/test_tokenization_xglm.py
rename to tests/models/xglm/test_tokenization_xglm.py
index f7b270858252fa..05259ffaf9a335 100644
--- a/tests/xglm/test_tokenization_xglm.py
+++ b/tests/models/xglm/test_tokenization_xglm.py
@@ -13,20 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import pickle
import shutil
import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, XGLMTokenizer, XGLMTokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_sentencepiece
@@ -180,7 +179,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to unk, such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to unk, such as saoneuhaoesuth"
+ )
# fmt: off
original_tokenizer_encodings = [2, 1018, 67, 11, 1988, 2617, 5631, 278, 11, 3407, 48, 71630, 28085, 4, 3234, 157, 13, 6, 5, 6, 4, 3526, 768, 15, 659, 57, 298, 3983, 864, 129, 21, 6, 5, 13675, 377, 652, 7580, 10341, 155, 2817, 422, 1666, 7, 1674, 53, 113, 202277, 17892, 33, 60, 87, 4, 3234, 157, 61, 2667, 52376, 19, 88, 23, 735]
# fmt: on
diff --git a/tests/models/xlm/__init__.py b/tests/models/xlm/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/xlm/test_modeling_tf_xlm.py b/tests/models/xlm/test_modeling_tf_xlm.py
similarity index 98%
rename from tests/xlm/test_modeling_tf_xlm.py
rename to tests/models/xlm/test_modeling_tf_xlm.py
index 412a8430ad6d6d..00e77cee64ba89 100644
--- a/tests/xlm/test_modeling_tf_xlm.py
+++ b/tests/models/xlm/test_modeling_tf_xlm.py
@@ -19,8 +19,8 @@
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/xlm/test_modeling_xlm.py b/tests/models/xlm/test_modeling_xlm.py
similarity index 98%
rename from tests/xlm/test_modeling_xlm.py
rename to tests/models/xlm/test_modeling_xlm.py
index f336221072960e..8f56ed8472ea81 100644
--- a/tests/xlm/test_modeling_xlm.py
+++ b/tests/models/xlm/test_modeling_xlm.py
@@ -18,9 +18,9 @@
from transformers import XLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/xlm/test_tokenization_xlm.py b/tests/models/xlm/test_tokenization_xlm.py
similarity index 98%
rename from tests/xlm/test_tokenization_xlm.py
rename to tests/models/xlm/test_tokenization_xlm.py
index bd056b69d43091..adb4835eda4070 100644
--- a/tests/xlm/test_tokenization_xlm.py
+++ b/tests/models/xlm/test_tokenization_xlm.py
@@ -21,7 +21,7 @@
from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES, XLMTokenizer
from transformers.testing_utils import slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
diff --git a/tests/models/xlm_prophetnet/__init__.py b/tests/models/xlm_prophetnet/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/xlm_prophetnet/test_modeling_xlm_prophetnet.py b/tests/models/xlm_prophetnet/test_modeling_xlm_prophetnet.py
similarity index 82%
rename from tests/xlm_prophetnet/test_modeling_xlm_prophetnet.py
rename to tests/models/xlm_prophetnet/test_modeling_xlm_prophetnet.py
index 51e8502b9bd5ac..5dec186bc7b9ce 100644
--- a/tests/xlm_prophetnet/test_modeling_xlm_prophetnet.py
+++ b/tests/models/xlm_prophetnet/test_modeling_xlm_prophetnet.py
@@ -102,8 +102,18 @@ def test_xprophetnet_ntg_inference(self):
tokenizer = XLMProphetNetTokenizer.from_pretrained("microsoft/xprophetnet-large-wiki100-cased-xglue-ntg")
- EN_SENTENCE = "Microsoft Corporation intends to officially end free support for the Windows 7 operating system after January 14, 2020, according to the official portal of the organization. From that day, users of this system will not be able to receive security updates, which could make their computers vulnerable to cyber attacks."
- RU_SENTENCE = "Š¾ŃŠæŠ¾ŃŠ°ŃŠøŃ Microsoft Š½Š°Š¼ŠµŃŠµŠ½Š° Š¾ŃŠøŃŠøŠ°Š»ŃŠ½Š¾ ŠæŃŠµŠŗŃŠ°ŃŠøŃŃ Š±ŠµŃŠæŠ»Š°ŃŠ½ŃŃ ŠæŠ¾Š“Š“ŠµŃŠ¶ŠŗŃ Š¾ŠæŠµŃŠ°ŃŠøŠ¾Š½Š½Š¾Š¹ ŃŠøŃŃŠµŠ¼Ń Windows 7 ŠæŠ¾ŃŠ»Šµ 14 ŃŠ½Š²Š°ŃŃ 2020 Š³Š¾Š“Š°, ŃŠ¾Š¾Š±ŃŠ°ŠµŃŃŃ Š½Š° Š¾ŃŠøŃŠøŠ°Š»ŃŠ½Š¾Š¼ ŠæŠ¾ŃŃŠ°Š»Šµ Š¾ŃŠ³Š°Š½ŠøŠ·Š°ŃŠøŠø . Š” ŃŠŗŠ°Š·Š°Š½Š½Š¾Š³Š¾ Š“Š½Ń ŠæŠ¾Š»ŃŠ·Š¾Š²Š°ŃŠµŠ»Šø ŃŃŠ¾Š¹ ŃŠøŃŃŠµŠ¼Ń Š½Šµ ŃŠ¼Š¾Š³ŃŃ ŠæŠ¾Š»ŃŃŠ°ŃŃ Š¾Š±Š½Š¾Š²Š»ŠµŠ½ŠøŃ Š±ŠµŠ·Š¾ŠæŠ°ŃŠ½Š¾ŃŃŠø, ŠøŠ·-Š·Š° ŃŠµŠ³Š¾ ŠøŃ
ŠŗŠ¾Š¼ŠæŃŃŃŠµŃŃ Š¼Š¾Š³ŃŃ ŃŃŠ°ŃŃ ŃŃŠ·Š²ŠøŠ¼ŃŠ¼Šø Šŗ ŠŗŠøŠ±ŠµŃŠ°ŃŠ°ŠŗŠ°Š¼."
+ EN_SENTENCE = (
+ "Microsoft Corporation intends to officially end free support for the Windows 7 operating system after"
+ " January 14, 2020, according to the official portal of the organization. From that day, users of this"
+ " system will not be able to receive security updates, which could make their computers vulnerable to"
+ " cyber attacks."
+ )
+ RU_SENTENCE = (
+ "Š¾ŃŠæŠ¾ŃŠ°ŃŠøŃ Microsoft Š½Š°Š¼ŠµŃŠµŠ½Š° Š¾ŃŠøŃŠøŠ°Š»ŃŠ½Š¾ ŠæŃŠµŠŗŃŠ°ŃŠøŃŃ Š±ŠµŃŠæŠ»Š°ŃŠ½ŃŃ ŠæŠ¾Š“Š“ŠµŃŠ¶ŠŗŃ Š¾ŠæŠµŃŠ°ŃŠøŠ¾Š½Š½Š¾Š¹ ŃŠøŃŃŠµŠ¼Ń Windows 7"
+ " ŠæŠ¾ŃŠ»Šµ 14 ŃŠ½Š²Š°ŃŃ 2020 Š³Š¾Š“Š°, ŃŠ¾Š¾Š±ŃŠ°ŠµŃŃŃ Š½Š° Š¾ŃŠøŃŠøŠ°Š»ŃŠ½Š¾Š¼ ŠæŠ¾ŃŃŠ°Š»Šµ Š¾ŃŠ³Š°Š½ŠøŠ·Š°ŃŠøŠø . Š” ŃŠŗŠ°Š·Š°Š½Š½Š¾Š³Š¾ Š“Š½Ń ŠæŠ¾Š»ŃŠ·Š¾Š²Š°ŃŠµŠ»Šø"
+ " ŃŃŠ¾Š¹ ŃŠøŃŃŠµŠ¼Ń Š½Šµ ŃŠ¼Š¾Š³ŃŃ ŠæŠ¾Š»ŃŃŠ°ŃŃ Š¾Š±Š½Š¾Š²Š»ŠµŠ½ŠøŃ Š±ŠµŠ·Š¾ŠæŠ°ŃŠ½Š¾ŃŃŠø, ŠøŠ·-Š·Š° ŃŠµŠ³Š¾ ŠøŃ
ŠŗŠ¾Š¼ŠæŃŃŃŠµŃŃ Š¼Š¾Š³ŃŃ ŃŃŠ°ŃŃ ŃŃŠ·Š²ŠøŠ¼ŃŠ¼Šø"
+ " Šŗ ŠŗŠøŠ±ŠµŃŠ°ŃŠ°ŠŗŠ°Š¼."
+ )
ZH_SENTENCE = (
"ę ¹ę®čÆ„ē»ē»ēå®ę¹éØę·ē½ē«ļ¼å¾®č½Æå
¬åøęē®åØ2020幓1ę14ę„ä¹åę£å¼ē»ę¢åƹWindows 7ęä½ē³»ē»ēå
č“¹ęÆęćä»é£ę¶čµ·ļ¼čÆ„ē³»ē»ēēØę·å°ę ę³ę„ę¶å®å
Øę“ę°ļ¼čæåÆč½ä¼ä½æä»ä»¬ēč®”ē®ęŗ容ęåå°ē½ē»ę»å»ć"
)
@@ -132,8 +142,9 @@ def test_xprophetnet_ntg_inference(self):
tokenizer.convert_ids_to_tokens(g, skip_special_tokens=True) for g in summary_ids_beam1
]
EXPECTED_TITLE_EN_BEAM1_TOK = "āMicrosoft āto āend āfree āsupport āfor āWindows ā7".split(" ")
- EXPECTED_TITLE_RU_BEAM1_TOK = "āMicrosoft āŠ½Š°Š¼ŠµŃŠµŠ½ Š° āŠæŃŠµŠŗŃŠ°ŃŠø ŃŃ āŠ±ŠµŃ ŠæŠ»Š°Ń Š½ŃŃ āŠæŠ¾Š“Š“ŠµŃŠ¶ŠŗŃ āWindows ā7 āŠæŠ¾ŃŠ»Šµ ā14 āŃŠ½Š²Š°ŃŃ ā2020 āŠ³Š¾Š“Š°".split(
- " "
+ EXPECTED_TITLE_RU_BEAM1_TOK = (
+ "āMicrosoft āŠ½Š°Š¼ŠµŃŠµŠ½ Š° āŠæŃŠµŠŗŃŠ°ŃŠø ŃŃ āŠ±ŠµŃ ŠæŠ»Š°Ń Š½ŃŃ āŠæŠ¾Š“Š“ŠµŃŠ¶ŠŗŃ āWindows ā7 āŠæŠ¾ŃŠ»Šµ ā14 āŃŠ½Š²Š°ŃŃ ā2020 āŠ³Š¾Š“Š°"
+ .split(" ")
)
EXPECTED_TITLE_ZH_BEAM1_TOK = "å¾®č½Æ å
¬åø ęē® ē»ę¢ åƹ Windows ā7 ęä½ ē³»ē»ē å
č“¹ ęÆę".split(" ")
self.assertListEqual(
diff --git a/tests/xlm_prophetnet/test_tokenization_xlm_prophetnet.py b/tests/models/xlm_prophetnet/test_tokenization_xlm_prophetnet.py
similarity index 96%
rename from tests/xlm_prophetnet/test_tokenization_xlm_prophetnet.py
rename to tests/models/xlm_prophetnet/test_tokenization_xlm_prophetnet.py
index c8f7568763109c..d560007fe3163f 100644
--- a/tests/xlm_prophetnet/test_tokenization_xlm_prophetnet.py
+++ b/tests/models/xlm_prophetnet/test_tokenization_xlm_prophetnet.py
@@ -13,18 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from transformers.models.xlm_prophetnet.tokenization_xlm_prophetnet import SPIECE_UNDERLINE, XLMProphetNetTokenizer
-from transformers.testing_utils import require_sentencepiece, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, slow
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_sentencepiece
diff --git a/tests/models/xlm_roberta/__init__.py b/tests/models/xlm_roberta/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/xlm_roberta/test_modeling_flax_xlm_roberta.py b/tests/models/xlm_roberta/test_modeling_flax_xlm_roberta.py
similarity index 100%
rename from tests/xlm_roberta/test_modeling_flax_xlm_roberta.py
rename to tests/models/xlm_roberta/test_modeling_flax_xlm_roberta.py
diff --git a/tests/xlm_roberta/test_modeling_tf_xlm_roberta.py b/tests/models/xlm_roberta/test_modeling_tf_xlm_roberta.py
similarity index 100%
rename from tests/xlm_roberta/test_modeling_tf_xlm_roberta.py
rename to tests/models/xlm_roberta/test_modeling_tf_xlm_roberta.py
diff --git a/tests/xlm_roberta/test_modeling_xlm_roberta.py b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py
similarity index 100%
rename from tests/xlm_roberta/test_modeling_xlm_roberta.py
rename to tests/models/xlm_roberta/test_modeling_xlm_roberta.py
diff --git a/tests/xlm_roberta/test_tokenization_xlm_roberta.py b/tests/models/xlm_roberta/test_tokenization_xlm_roberta.py
similarity index 96%
rename from tests/xlm_roberta/test_tokenization_xlm_roberta.py
rename to tests/models/xlm_roberta/test_tokenization_xlm_roberta.py
index 0ba1492efc099b..c8f934b258b93c 100644
--- a/tests/xlm_roberta/test_tokenization_xlm_roberta.py
+++ b/tests/models/xlm_roberta/test_tokenization_xlm_roberta.py
@@ -13,21 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import pickle
import shutil
import tempfile
import unittest
-from os.path import dirname
from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_sentencepiece
@@ -258,7 +256,10 @@ def test_tokenization_base_easy_symbols(self):
@slow
def test_tokenization_base_hard_symbols(self):
- symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to , such as saoneuhaoesuth'
+ symbols = (
+ 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
+ " add words that should not exsist and be tokenized to , such as saoneuhaoesuth"
+ )
original_tokenizer_encodings = [
0,
3293,
diff --git a/tests/models/xlm_roberta_xl/__init__.py b/tests/models/xlm_roberta_xl/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
similarity index 98%
rename from tests/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
rename to tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
index a3e7e64481ce0f..b889753f663c13 100644
--- a/tests/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
+++ b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
@@ -19,9 +19,9 @@
from transformers import XLMRobertaXLConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
diff --git a/tests/models/xlnet/__init__.py b/tests/models/xlnet/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/xlnet/test_modeling_tf_xlnet.py b/tests/models/xlnet/test_modeling_tf_xlnet.py
similarity index 99%
rename from tests/xlnet/test_modeling_tf_xlnet.py
rename to tests/models/xlnet/test_modeling_tf_xlnet.py
index 8cf4ca2099bd5d..dc1ca077952cf7 100644
--- a/tests/xlnet/test_modeling_tf_xlnet.py
+++ b/tests/models/xlnet/test_modeling_tf_xlnet.py
@@ -21,8 +21,8 @@
from transformers import XLNetConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
diff --git a/tests/xlnet/test_modeling_xlnet.py b/tests/models/xlnet/test_modeling_xlnet.py
similarity index 99%
rename from tests/xlnet/test_modeling_xlnet.py
rename to tests/models/xlnet/test_modeling_xlnet.py
index 420d22cc1e20c2..dca727b299426b 100644
--- a/tests/xlnet/test_modeling_xlnet.py
+++ b/tests/models/xlnet/test_modeling_xlnet.py
@@ -19,9 +19,9 @@
from transformers import XLNetConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..generation.test_generation_utils import GenerationTesterMixin
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
all_generative_model_classes = (
(XLNetLMHeadModel,) if is_torch_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
+ fx_compatible = False
test_pruning = False
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
diff --git a/tests/xlnet/test_tokenization_xlnet.py b/tests/models/xlnet/test_tokenization_xlnet.py
similarity index 97%
rename from tests/xlnet/test_tokenization_xlnet.py
rename to tests/models/xlnet/test_tokenization_xlnet.py
index 707c975201ac70..6125a1dffd7791 100644
--- a/tests/xlnet/test_tokenization_xlnet.py
+++ b/tests/models/xlnet/test_tokenization_xlnet.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import unittest
-from os.path import dirname
from transformers import SPIECE_UNDERLINE, XLNetTokenizer, XLNetTokenizerFast
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
-from ..test_tokenization_common import TokenizerTesterMixin
+from ...test_tokenization_common import TokenizerTesterMixin
-SAMPLE_VOCAB = os.path.join(dirname(dirname(os.path.abspath(__file__))), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_sentencepiece
diff --git a/tests/models/yolos/__init__.py b/tests/models/yolos/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/yolos/test_feature_extraction_yolos.py b/tests/models/yolos/test_feature_extraction_yolos.py
new file mode 100644
index 00000000000000..8a576a583a9af0
--- /dev/null
+++ b/tests/models/yolos/test_feature_extraction_yolos.py
@@ -0,0 +1,336 @@
+# coding=utf-8
+# Copyright 2021 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import pathlib
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision, slow
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import YolosFeatureExtractor
+
+
+class YolosFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=18,
+ max_size=1333, # by setting max_size > max_resolution we're effectively not testing this :p
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.max_size = max_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "max_size": self.max_size,
+ "do_normalize": self.do_normalize,
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ }
+
+ def get_expected_values(self, image_inputs, batched=False):
+ """
+ This function computes the expected height and width when providing images to YolosFeatureExtractor,
+ assuming do_resize is set to True with a scalar size.
+ """
+ if not batched:
+ image = image_inputs[0]
+ if isinstance(image, Image.Image):
+ w, h = image.size
+ else:
+ h, w = image.shape[1], image.shape[2]
+ if w < h:
+ expected_height = int(self.size * h / w)
+ expected_width = self.size
+ elif w > h:
+ expected_height = self.size
+ expected_width = int(self.size * w / h)
+ else:
+ expected_height = self.size
+ expected_width = self.size
+
+ else:
+ expected_values = []
+ for image in image_inputs:
+ expected_height, expected_width = self.get_expected_values([image])
+ expected_values.append((expected_height, expected_width))
+ expected_height = max(expected_values, key=lambda item: item[0])[0]
+ expected_width = max(expected_values, key=lambda item: item[1])[1]
+
+ return expected_height, expected_width
+
+
+@require_torch
+@require_vision
+class YolosFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = YolosFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = YolosFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "max_size"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
+
+ self.assertEqual(
+ encoded_images.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
+
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
+
+ self.assertEqual(
+ encoded_images.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
+
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
+
+ self.assertEqual(
+ encoded_images.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
+
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_equivalence_padding(self):
+ # Initialize feature_extractors
+ feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
+ feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test whether the method "pad" and calling the feature extractor return the same tensors
+ encoded_images_with_method = feature_extractor_1.pad(image_inputs, return_tensors="pt")
+ encoded_images = feature_extractor_2(image_inputs, return_tensors="pt")
+
+ assert torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4)
+
+ @slow
+ def test_call_pytorch_with_coco_detection_annotations(self):
+ # prepare image and target
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
+ target = json.loads(f.read())
+
+ target = {"image_id": 39769, "annotations": target}
+
+ # encode them
+ feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small")
+ encoding = feature_extractor(images=image, annotations=target, return_tensors="pt")
+
+ # verify pixel values
+ expected_shape = torch.Size([1, 3, 800, 1066])
+ self.assertEqual(encoding["pixel_values"].shape, expected_shape)
+
+ expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
+ assert torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4)
+
+ # verify area
+ expected_area = torch.tensor([5887.9600, 11250.2061, 489353.8438, 837122.7500, 147967.5156, 165732.3438])
+ assert torch.allclose(encoding["labels"][0]["area"], expected_area)
+ # verify boxes
+ expected_boxes_shape = torch.Size([6, 4])
+ self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
+ expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
+ assert torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3)
+ # verify image_id
+ expected_image_id = torch.tensor([39769])
+ assert torch.allclose(encoding["labels"][0]["image_id"], expected_image_id)
+ # verify is_crowd
+ expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
+ assert torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd)
+ # verify class_labels
+ expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
+ assert torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels)
+ # verify orig_size
+ expected_orig_size = torch.tensor([480, 640])
+ assert torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size)
+ # verify size
+ expected_size = torch.tensor([800, 1066])
+ assert torch.allclose(encoding["labels"][0]["size"], expected_size)
+
+ @slow
+ def test_call_pytorch_with_coco_panoptic_annotations(self):
+ # prepare image, target and masks_path
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ with open("./tests/fixtures/tests_samples/COCO/coco_panoptic_annotations.txt", "r") as f:
+ target = json.loads(f.read())
+
+ target = {"file_name": "000000039769.png", "image_id": 39769, "segments_info": target}
+
+ masks_path = pathlib.Path("./tests/fixtures/tests_samples/COCO/coco_panoptic")
+
+ # encode them
+ feature_extractor = YolosFeatureExtractor(format="coco_panoptic")
+ encoding = feature_extractor(images=image, annotations=target, masks_path=masks_path, return_tensors="pt")
+
+ # verify pixel values
+ expected_shape = torch.Size([1, 3, 800, 1066])
+ self.assertEqual(encoding["pixel_values"].shape, expected_shape)
+
+ expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
+ assert torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4)
+
+ # verify area
+ expected_area = torch.tensor([147979.6875, 165527.0469, 484638.5938, 11292.9375, 5879.6562, 7634.1147])
+ assert torch.allclose(encoding["labels"][0]["area"], expected_area)
+ # verify boxes
+ expected_boxes_shape = torch.Size([6, 4])
+ self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
+ expected_boxes_slice = torch.tensor([0.2625, 0.5437, 0.4688, 0.8625])
+ assert torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3)
+ # verify image_id
+ expected_image_id = torch.tensor([39769])
+ assert torch.allclose(encoding["labels"][0]["image_id"], expected_image_id)
+ # verify is_crowd
+ expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
+ assert torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd)
+ # verify class_labels
+ expected_class_labels = torch.tensor([17, 17, 63, 75, 75, 93])
+ assert torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels)
+ # verify masks
+ expected_masks_sum = 822338
+ self.assertEqual(encoding["labels"][0]["masks"].sum().item(), expected_masks_sum)
+ # verify orig_size
+ expected_orig_size = torch.tensor([480, 640])
+ assert torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size)
+ # verify size
+ expected_size = torch.tensor([800, 1066])
+ assert torch.allclose(encoding["labels"][0]["size"], expected_size)
diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py
new file mode 100644
index 00000000000000..75d399eaa7972e
--- /dev/null
+++ b/tests/models/yolos/test_modeling_yolos.py
@@ -0,0 +1,373 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch YOLOS model. """
+
+
+import inspect
+import unittest
+
+from transformers import YolosConfig
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import YolosForObjectDetection, YolosModel
+ from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoFeatureExtractor
+
+
+class YolosModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=[30, 30],
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ num_labels=3,
+ scope=None,
+ n_targets=8,
+ num_detection_tokens=10,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.scope = scope
+ self.n_targets = n_targets
+ self.num_detection_tokens = num_detection_tokens
+ # we set the expected sequence length (which is used in several tests)
+ # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
+ image_size = to_2tuple(self.image_size)
+ patch_size = to_2tuple(self.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.expected_seq_len = num_patches + 1 + self.num_detection_tokens
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
+
+ labels = None
+ if self.use_labels:
+ # labels is a list of Dict (each Dict being the labels for a given example in the batch)
+ labels = []
+ for i in range(self.batch_size):
+ target = {}
+ target["class_labels"] = torch.randint(
+ high=self.num_labels, size=(self.n_targets,), device=torch_device
+ )
+ target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
+ labels.append(target)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return YolosConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ num_detection_tokens=self.num_detection_tokens,
+ num_labels=self.num_labels,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = YolosModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.expected_seq_len, self.hidden_size)
+ )
+
+ def create_and_check_for_object_detection(self, config, pixel_values, labels):
+ model = YolosForObjectDetection(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(pixel_values=pixel_values)
+ result = model(pixel_values)
+
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_detection_tokens, self.num_labels + 1))
+ self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_detection_tokens, 4))
+
+ result = model(pixel_values=pixel_values, labels=labels)
+
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_detection_tokens, self.num_labels + 1))
+ self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_detection_tokens, 4))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class YolosModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as YOLOS does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (YolosModel, YolosForObjectDetection) if is_torch_available() else ()
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ test_torchscript = False
+
+ # special case for head model
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+ if return_labels:
+ if model_class.__name__ == "YolosForObjectDetection":
+ labels = []
+ for i in range(self.model_tester.batch_size):
+ target = {}
+ target["class_labels"] = torch.ones(
+ size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
+ )
+ target["boxes"] = torch.ones(
+ self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float
+ )
+ labels.append(target)
+ inputs_dict["labels"] = labels
+
+ return inputs_dict
+
+ def setUp(self):
+ self.model_tester = YolosModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=YolosConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_inputs_embeds(self):
+ # YOLOS does not use inputs_embeds
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ # in YOLOS, the seq_len is different
+ seq_len = self.model_tester.expected_seq_len
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # YOLOS has a different seq_length
+ seq_length = self.model_tester.expected_seq_len
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_for_object_detection(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_object_detection(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = YolosModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class YolosModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return AutoFeatureExtractor.from_pretrained("hustvl/yolos-small") if is_vision_available() else None
+
+ @slow
+ def test_inference_object_detection_head(self):
+ model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small").to(torch_device)
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(inputs.pixel_values)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 100, 92))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice_logits = torch.tensor(
+ [[-24.0248, -10.3024, -14.8290], [-42.0392, -16.8200, -27.4334], [-27.2743, -11.8154, -18.7148]],
+ device=torch_device,
+ )
+ expected_slice_boxes = torch.tensor(
+ [[0.2559, 0.5455, 0.4706], [0.2989, 0.7279, 0.1875], [0.7732, 0.4017, 0.4462]], device=torch_device
+ )
+ self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
+ self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
diff --git a/tests/models/yoso/__init__.py b/tests/models/yoso/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/yoso/test_modeling_yoso.py b/tests/models/yoso/test_modeling_yoso.py
similarity index 98%
rename from tests/yoso/test_modeling_yoso.py
rename to tests/models/yoso/test_modeling_yoso.py
index f6d013b1bf8cd4..0a0749dd7d9bcd 100644
--- a/tests/yoso/test_modeling_yoso.py
+++ b/tests/models/yoso/test_modeling_yoso.py
@@ -20,8 +20,8 @@
from transformers import YosoConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@@ -126,6 +126,11 @@ def get_config(self):
initializer_range=self.initializer_range,
)
+ def get_pipeline_config(self):
+ config = self.get_config()
+ config.vocab_size = 300
+ return config
+
def prepare_config_and_inputs_for_decoder(self):
(
config,
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index ea5a54763932c9..3430bc9fbd81d4 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -6,7 +6,7 @@
import pytest
from parameterized import parameterized
-from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
+from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
@@ -15,8 +15,12 @@
export,
validate_model_outputs,
)
-from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
-from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
+from transformers.onnx.utils import (
+ compute_effective_axis_dimension,
+ compute_serialized_parameters_size,
+ get_preprocessor,
+)
+from transformers.testing_utils import require_onnx, require_rjieba, require_tf, require_torch, require_vision, slow
if is_torch_available() or is_tf_available():
@@ -176,16 +180,23 @@ def test_values_override(self):
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("convbert", "YituTech/conv-bert-base"),
+ ("convnext", "facebook/convnext-tiny-224"),
("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
+ ("resnet", "microsoft/resnet-50"),
("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"),
+ ("squeezebert", "squeezebert/squeezebert-uncased"),
+ ("mobilebert", "google/mobilebert-uncased"),
+ ("xlm", "xlm-clm-ende-1024"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),
("deit", "facebook/deit-small-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
("data2vec-text", "facebook/data2vec-text-base"),
+ ("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
+ ("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {
@@ -201,6 +212,9 @@ def test_values_override(self):
("m2m-100", "facebook/m2m100_418M"),
("blenderbot-small", "facebook/blenderbot_small-90M"),
("blenderbot", "facebook/blenderbot-400M-distill"),
+ ("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
+ ("longt5", "google/long-t5-local-base"),
+ ("longt5", "google/long-t5-tglobal-base"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
@@ -222,10 +236,15 @@ def test_values_override(self):
def _get_models_to_test(export_models_list):
models_to_test = []
if is_torch_available() or is_tf_available():
- for (name, model) in export_models_list:
- for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
- name
- ).items():
+ for name, model, *features in export_models_list:
+ if features:
+ feature_config_mapping = {
+ feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _
+ }
+ else:
+ feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(name)
+
+ for feature, onnx_config_class_constructor in feature_config_mapping.items():
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
return sorted(models_to_test)
else:
@@ -240,7 +259,7 @@ class OnnxExportTestCaseV2(TestCase):
Integration tests ensuring supported models are correctly exported
"""
- def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
+ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
from transformers.onnx import export
model_class = FeaturesManager.get_model_class_for_feature(feature)
@@ -253,24 +272,20 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
if torch_version < onnx_config.torch_onnx_minimum_version:
pytest.skip(
- f"Skipping due to incompatible PyTorch version. Minimum required is {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
+ "Skipping due to incompatible PyTorch version. Minimum required is"
+ f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
)
- # Check the modality of the inputs and instantiate the appropriate preprocessor
- if model.main_input_name == "input_ids":
- preprocessor = AutoTokenizer.from_pretrained(model_name)
- # Useful for causal lm models that do not use pad tokens.
- if not getattr(config, "pad_token_id", None):
- config.pad_token_id = preprocessor.eos_token_id
- elif model.main_input_name == "pixel_values":
- preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
- else:
- raise ValueError(f"Unsupported model input name: {model.main_input_name}")
+ preprocessor = get_preprocessor(model_name)
+
+ # Useful for causal lm models that do not use pad tokens.
+ if isinstance(preprocessor, PreTrainedTokenizerBase) and not getattr(config, "pad_token_id", None):
+ config.pad_token_id = preprocessor.eos_token_id
with NamedTemporaryFile("w") as output:
try:
onnx_inputs, onnx_outputs = export(
- preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
+ preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name), device=device
)
validate_model_outputs(
onnx_config,
@@ -287,9 +302,18 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
@slow
@require_torch
@require_vision
+ @require_rjieba
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+ @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
+ @slow
+ @require_torch
+ @require_vision
+ @require_rjieba
+ def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
+ self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
+
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
@slow
@require_torch
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index ec54055d7d62a7..25bf520eafb4d7 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -184,7 +184,9 @@ def test_large_model_pt_with_lm(self):
self.assertEqual(
output,
{
- "text": "y en las ramas medio sumergidas revoloteaban algunos pƔjaros de quimƩrico y legendario plumajre"
+ "text": (
+ "y en las ramas medio sumergidas revoloteaban algunos pƔjaros de quimƩrico y legendario plumajre"
+ )
},
)
@@ -194,7 +196,9 @@ def test_large_model_pt_with_lm(self):
self.assertEqual(
output,
{
- "text": "y en las ramas medio sumergidas revoloteaban algunos pƔjaros de quimƩrico y legendario plumajcri",
+ "text": (
+ "y en las ramas medio sumergidas revoloteaban algunos pƔjaros de quimƩrico y legendario plumajcri"
+ ),
"chunks": [
{"text": "y", "timestamp": (0.52, 0.54)},
{"text": "en", "timestamp": (0.6, 0.68)},
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index fe0dbd28153b71..6a6c8b73e52770 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -65,7 +65,7 @@ def get_tiny_config_from_class(configuration_class):
try:
model_slug = model_type.replace("-", "_")
- module = importlib.import_module(f".test_modeling_{model_slug}", package=f"tests.{model_slug}")
+ module = importlib.import_module(f".test_modeling_{model_slug}", package=f"tests.models.{model_slug}")
model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
except (ImportError, AttributeError):
logger.error(f"No model tester class for {configuration_class.__name__}")
@@ -184,7 +184,8 @@ def test(self):
if tokenizer is None and feature_extractor is None:
self.skipTest(
- f"Ignoring {ModelClass}, cannot create a tokenizer or feature_extractor (PerceiverConfig with no FastTokenizer ?)"
+ f"Ignoring {ModelClass}, cannot create a tokenizer or feature_extractor (PerceiverConfig with"
+ " no FastTokenizer ?)"
)
pipeline, examples = self.get_test_pipeline(model, tokenizer, feature_extractor)
if pipeline is None:
diff --git a/tests/pipelines/test_pipelines_fill_mask.py b/tests/pipelines/test_pipelines_fill_mask.py
index ed551bf6f4903f..d85ab8d7ce32a6 100644
--- a/tests/pipelines/test_pipelines_fill_mask.py
+++ b/tests/pipelines/test_pipelines_fill_mask.py
@@ -16,7 +16,14 @@
from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
from transformers.pipelines import PipelineException
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
+from transformers.testing_utils import (
+ is_pipeline_test,
+ nested_simplify,
+ require_tf,
+ require_torch,
+ require_torch_gpu,
+ slow,
+)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -130,6 +137,19 @@ def test_small_model_pt(self):
],
)
+ @require_torch_gpu
+ def test_fp16_casting(self):
+ pipe = pipeline("fill-mask", model="hf-internal-testing/tiny-random-distilbert", device=0, framework="pt")
+
+ # convert model to fp16
+ pipe.model.half()
+
+ response = pipe("Paris is the [MASK] of France.")
+ # We actually don't care about the result, we just want to make sure
+ # it works, meaning the float16 tensor got casted back to float32
+ # for postprocessing.
+ self.assertIsInstance(response, list)
+
@slow
@require_torch
def test_large_model_pt(self):
diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py
index e37fa12776835f..f34237612c11a9 100644
--- a/tests/pipelines/test_pipelines_question_answering.py
+++ b/tests/pipelines/test_pipelines_question_answering.py
@@ -106,17 +106,59 @@ def run_pipeline_test(self, question_answerer, _):
)
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
+ # Using batch is OK
+ new_outputs = question_answerer(
+ question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
+ )
+ self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
+ self.assertEqual(outputs, new_outputs)
+
@require_torch
def test_small_model_pt(self):
question_answerer = pipeline(
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
)
+
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
+ @require_torch
+ def test_small_model_pt_softmax_trick(self):
+ question_answerer = pipeline(
+ "question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
+ )
+
+ real_postprocess = question_answerer.postprocess
+
+ # Tweak start and stop to make sure we encounter the softmax logits
+ # bug.
+ def ensure_large_logits_postprocess(
+ model_outputs,
+ top_k=1,
+ handle_impossible_answer=False,
+ max_answer_len=15,
+ ):
+ for output in model_outputs:
+ output["start"] = output["start"] * 1e6
+ output["end"] = output["end"] * 1e6
+ return real_postprocess(
+ model_outputs,
+ top_k=top_k,
+ handle_impossible_answer=handle_impossible_answer,
+ max_answer_len=max_answer_len,
+ )
+
+ question_answerer.postprocess = ensure_large_logits_postprocess
+
+ outputs = question_answerer(
+ question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
+ )
+
+ self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"})
+
@slow
@require_torch
def test_small_model_long_context_cls_slow(self):
@@ -164,7 +206,42 @@ def test_large_model_issue(self):
)
outputs = qa_pipeline(
{
- "context": "Yes Bank founder Rana Kapoor has approached the Bombay High Court, challenging a special court's order from August this year that had remanded him in police custody for a week in a multi-crore loan fraud case. Kapoor, who is currently lodged in Taloja Jail, is an accused in the loan fraud case and some related matters being probed by the CBI and Enforcement Directorate. A single bench presided over by Justice S K Shinde on Tuesday posted the plea for further hearing on October 14. In his plea filed through advocate Vijay Agarwal, Kapoor claimed that the special court's order permitting the CBI's request for police custody on August 14 was illegal and in breach of the due process of law. Therefore, his police custody and subsequent judicial custody in the case were all illegal. Kapoor has urged the High Court to quash and set aside the special court's order dated August 14. As per his plea, in August this year, the CBI had moved two applications before the special court, one seeking permission to arrest Kapoor, who was already in judicial custody at the time in another case, and the other, seeking his police custody. While the special court refused to grant permission to the CBI to arrest Kapoor, it granted the central agency's plea for his custody. Kapoor, however, said in his plea that before filing an application for his arrest, the CBI had not followed the process of issuing him a notice under Section 41 of the CrPC for appearance before it. He further said that the CBI had not taken prior sanction as mandated under section 17 A of the Prevention of Corruption Act for prosecuting him. The special court, however, had said in its order at the time that as Kapoor was already in judicial custody in another case and was not a free man the procedure mandated under Section 41 of the CrPC need not have been adhered to as far as issuing a prior notice of appearance was concerned. ADVERTISING It had also said that case records showed that the investigating officer had taken an approval from a managing director of Yes Bank before beginning the proceedings against Kapoor and such a permission was a valid sanction. However, Kapoor in his plea said that the above order was bad in law and sought that it be quashed and set aside. The law mandated that if initial action was not in consonance with legal procedures, then all subsequent actions must be held as illegal, he said, urging the High Court to declare the CBI remand and custody and all subsequent proceedings including the further custody as illegal and void ab-initio. In a separate plea before the High Court, Kapoor's daughter Rakhee Kapoor-Tandon has sought exemption from in-person appearance before a special PMLA court. Rakhee has stated that she is a resident of the United Kingdom and is unable to travel to India owing to restrictions imposed due to the COVID-19 pandemic. According to the CBI, in the present case, Kapoor had obtained a gratification or pecuniary advantage of ā¹ 307 crore, and thereby caused Yes Bank a loss of ā¹ 1,800 crore by extending credit facilities to Avantha Group, when it was not eligible for the same",
+ "context": (
+ "Yes Bank founder Rana Kapoor has approached the Bombay High Court, challenging a special court's"
+ " order from August this year that had remanded him in police custody for a week in a multi-crore"
+ " loan fraud case. Kapoor, who is currently lodged in Taloja Jail, is an accused in the loan fraud"
+ " case and some related matters being probed by the CBI and Enforcement Directorate. A single"
+ " bench presided over by Justice S K Shinde on Tuesday posted the plea for further hearing on"
+ " October 14. In his plea filed through advocate Vijay Agarwal, Kapoor claimed that the special"
+ " court's order permitting the CBI's request for police custody on August 14 was illegal and in"
+ " breach of the due process of law. Therefore, his police custody and subsequent judicial custody"
+ " in the case were all illegal. Kapoor has urged the High Court to quash and set aside the special"
+ " court's order dated August 14. As per his plea, in August this year, the CBI had moved two"
+ " applications before the special court, one seeking permission to arrest Kapoor, who was already"
+ " in judicial custody at the time in another case, and the other, seeking his police custody."
+ " While the special court refused to grant permission to the CBI to arrest Kapoor, it granted the"
+ " central agency's plea for his custody. Kapoor, however, said in his plea that before filing an"
+ " application for his arrest, the CBI had not followed the process of issuing him a notice under"
+ " Section 41 of the CrPC for appearance before it. He further said that the CBI had not taken"
+ " prior sanction as mandated under section 17 A of the Prevention of Corruption Act for"
+ " prosecuting him. The special court, however, had said in its order at the time that as Kapoor"
+ " was already in judicial custody in another case and was not a free man the procedure mandated"
+ " under Section 41 of the CrPC need not have been adhered to as far as issuing a prior notice of"
+ " appearance was concerned. ADVERTISING It had also said that case records showed that the"
+ " investigating officer had taken an approval from a managing director of Yes Bank before"
+ " beginning the proceedings against Kapoor and such a permission was a valid sanction. However,"
+ " Kapoor in his plea said that the above order was bad in law and sought that it be quashed and"
+ " set aside. The law mandated that if initial action was not in consonance with legal procedures,"
+ " then all subsequent actions must be held as illegal, he said, urging the High Court to declare"
+ " the CBI remand and custody and all subsequent proceedings including the further custody as"
+ " illegal and void ab-initio. In a separate plea before the High Court, Kapoor's daughter Rakhee"
+ " Kapoor-Tandon has sought exemption from in-person appearance before a special PMLA court. Rakhee"
+ " has stated that she is a resident of the United Kingdom and is unable to travel to India owing"
+ " to restrictions imposed due to the COVID-19 pandemic. According to the CBI, in the present case,"
+ " Kapoor had obtained a gratification or pecuniary advantage of ā¹ 307 crore, and thereby caused"
+ " Yes Bank a loss of ā¹ 1,800 crore by extending credit facilities to Avantha Group, when it was"
+ " not eligible for the same"
+ ),
"question": "Is this person invovled in fraud?",
}
)
diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py
index e434ed742dc70f..d797383811c6ae 100644
--- a/tests/pipelines/test_pipelines_summarization.py
+++ b/tests/pipelines/test_pipelines_summarization.py
@@ -18,6 +18,7 @@
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
LEDConfig,
+ LongT5Config,
SummarizationPipeline,
T5Config,
pipeline,
@@ -54,8 +55,8 @@ def run_pipeline_test(self, summarizer, _):
)
self.assertEqual(outputs, [{"summary_text": ANY(str)}])
- if not isinstance(model.config, (T5Config, LEDConfig)):
- # LED, T5 can handle it.
+ if not isinstance(model.config, (T5Config, LongT5Config, LEDConfig)):
+ # LED, T5, LongT5 can handle it.
# Too long.
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
@@ -91,7 +92,49 @@ def test_small_model_tf(self):
@slow
def test_integration_torch_summarization(self):
summarizer = pipeline(task="summarization", device=DEFAULT_DEVICE_NUM)
- cnn_article = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
- expected_cnn_summary = " The Palestinian Authority becomes the 123rd member of the International Criminal Court . The move gives the court jurisdiction over alleged crimes in Palestinian territories . Israel and the United States opposed the Palestinians' efforts to join the court . Rights group Human Rights Watch welcomes the move, says governments seeking to penalize Palestine should end pressure ."
+ cnn_article = (
+ " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
+ " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
+ " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
+ " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
+ ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
+ ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
+ " situation in Palestinian territories, paving the way for possible war crimes investigations against"
+ " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
+ " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
+ " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
+ ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
+ ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
+ ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
+ " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
+ ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
+ " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
+ ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
+ ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
+ " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
+ ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
+ " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
+ ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
+ " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
+ ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
+ " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
+ ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
+ ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
+ ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
+ " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
+ ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
+ " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
+ ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
+ " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
+ " will include alleged war crimes committed since June. The International Criminal Court was set up in"
+ " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
+ " and Faith Karimi contributed to this report."
+ )
+ expected_cnn_summary = (
+ " The Palestinian Authority becomes the 123rd member of the International Criminal Court . The move gives"
+ " the court jurisdiction over alleged crimes in Palestinian territories . Israel and the United States"
+ " opposed the Palestinians' efforts to join the court . Rights group Human Rights Watch welcomes the move,"
+ " says governments seeking to penalize Palestine should end pressure ."
+ )
result = summarizer(cnn_article)
self.assertEqual(result[0]["summary_text"], expected_cnn_summary)
diff --git a/tests/pipelines/test_pipelines_table_question_answering.py b/tests/pipelines/test_pipelines_table_question_answering.py
index 86bbf991b03922..ba7fdaa75c5017 100644
--- a/tests/pipelines/test_pipelines_table_question_answering.py
+++ b/tests/pipelines/test_pipelines_table_question_answering.py
@@ -92,7 +92,8 @@ def test_small_model_tf(self):
},
query=[
"What repository has the largest number of stars?",
- "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
+ "Given that the numbers of stars defines if a repository is active, what repository is the most"
+ " active?",
"What is the number of repositories?",
"What is the average number of stars?",
"What is the total amount of stars?",
@@ -194,7 +195,8 @@ def test_small_model_pt(self):
},
query=[
"What repository has the largest number of stars?",
- "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
+ "Given that the numbers of stars defines if a repository is active, what repository is the most"
+ " active?",
"What is the number of repositories?",
"What is the average number of stars?",
"What is the total amount of stars?",
@@ -313,7 +315,8 @@ def test_slow_tokenizer_sqa_pt(self):
},
query=[
"What repository has the largest number of stars?",
- "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
+ "Given that the numbers of stars defines if a repository is active, what repository is the most"
+ " active?",
"What is the number of repositories?",
"What is the average number of stars?",
"What is the total amount of stars?",
@@ -434,7 +437,8 @@ def test_slow_tokenizer_sqa_tf(self):
},
query=[
"What repository has the largest number of stars?",
- "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
+ "Given that the numbers of stars defines if a repository is active, what repository is the most"
+ " active?",
"What is the number of repositories?",
"What is the average number of stars?",
"What is the total amount of stars?",
diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py
index 39deed9bee55c9..9251b299224c52 100644
--- a/tests/pipelines/test_pipelines_text_classification.py
+++ b/tests/pipelines/test_pipelines_text_classification.py
@@ -39,6 +39,41 @@ def test_small_model_pt(self):
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+ outputs = text_classifier("This is great !", top_k=2)
+ self.assertEqual(
+ nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]
+ )
+
+ outputs = text_classifier(["This is great !", "This is bad"], top_k=2)
+ self.assertEqual(
+ nested_simplify(outputs),
+ [
+ [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
+ [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
+ ],
+ )
+
+ outputs = text_classifier("This is great !", top_k=1)
+ self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+
+ # Legacy behavior
+ outputs = text_classifier("This is great !", return_all_scores=False)
+ self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+
+ @require_torch
+ def test_accepts_torch_device(self):
+ import torch
+
+ text_classifier = pipeline(
+ task="text-classification",
+ model="hf-internal-testing/tiny-random-distilbert",
+ framework="pt",
+ device=torch.device("cpu"),
+ )
+
+ outputs = text_classifier("This is great !")
+ self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+
@require_tf
def test_small_model_tf(self):
text_classifier = pipeline(
@@ -93,3 +128,37 @@ def run_pipeline_test(self, text_classifier, _):
)
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
self.assertTrue(outputs[1]["label"] in model.config.id2label.values())
+
+ # Forcing to get all results with `top_k=None`
+ # This is NOT the legacy format
+ outputs = text_classifier(valid_inputs, top_k=None)
+ N = len(model.config.id2label.values())
+ self.assertEqual(
+ nested_simplify(outputs),
+ [[{"label": ANY(str), "score": ANY(float)}] * N, [{"label": ANY(str), "score": ANY(float)}] * N],
+ )
+
+ valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"}
+ outputs = text_classifier(valid_inputs)
+ self.assertEqual(
+ nested_simplify(outputs),
+ {"label": ANY(str), "score": ANY(float)},
+ )
+ self.assertTrue(outputs["label"] in model.config.id2label.values())
+
+ # This might be used a text pair, but tokenizer + pipe interaction
+ # makes it hard to understand that it's not using the pair properly
+ # https://github.com/huggingface/transformers/issues/17305
+ # We disabled this usage instead as it was outputting wrong outputs.
+ invalid_input = [["HuggingFace is in ", "Paris is in France"]]
+ with self.assertRaises(ValueError):
+ text_classifier(invalid_input)
+
+ # This used to be valid for doing text pairs
+ # We're keeping it working because of backward compatibility
+ outputs = text_classifier([[["HuggingFace is in ", "Paris is in France"]]])
+ self.assertEqual(
+ nested_simplify(outputs),
+ [{"label": ANY(str), "score": ANY(float)}],
+ )
+ self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index ca67c3bea13d75..929e2732f092fc 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -34,7 +34,10 @@ def test_small_model_pt(self):
outputs,
[
{
- "generated_text": "This is a test ā ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+ "generated_text": (
+ "This is a test ā ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope."
+ " oscope. FiliFili@@"
+ )
}
],
)
@@ -45,12 +48,18 @@ def test_small_model_pt(self):
[
[
{
- "generated_text": "This is a test ā ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+ "generated_text": (
+ "This is a test ā ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope."
+ " oscope. FiliFili@@"
+ )
}
],
[
{
- "generated_text": "This is a second test ā segmental segmental segmental 议议eski eski flutter flutter Lacy oscope. oscope. FiliFili@@"
+ "generated_text": (
+ "This is a second test ā segmental segmental segmental 议议eski eski flutter flutter Lacy"
+ " oscope. oscope. FiliFili@@"
+ )
}
],
],
@@ -97,7 +106,10 @@ def test_small_model_tf(self):
outputs,
[
{
- "generated_text": "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ please,"
+ "generated_text": (
+ "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ"
+ " please,"
+ )
}
],
)
@@ -108,12 +120,18 @@ def test_small_model_tf(self):
[
[
{
- "generated_text": "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ please,"
+ "generated_text": (
+ "This is a test FeyFeyFey(Croatis.), s.), Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ"
+ " please,"
+ )
}
],
[
{
- "generated_text": "This is a second test Chieftain Chieftain prefecture prefecture prefecture Cannes Cannes Cannes é²é²Cannes Cannes Cannes ęµ please,"
+ "generated_text": (
+ "This is a second test Chieftain Chieftain prefecture prefecture prefecture Cannes Cannes"
+ " Cannes é²é²Cannes Cannes Cannes ęµ please,"
+ )
}
],
],
diff --git a/tests/pipelines/test_pipelines_translation.py b/tests/pipelines/test_pipelines_translation.py
index 368f6bc9c5cc79..3c5999f36e60dc 100644
--- a/tests/pipelines/test_pipelines_translation.py
+++ b/tests/pipelines/test_pipelines_translation.py
@@ -61,7 +61,10 @@ def test_small_model_pt(self):
outputs,
[
{
- "translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+ "translation_text": (
+ "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+ " Beide Beide"
+ )
}
],
)
@@ -74,7 +77,10 @@ def test_small_model_tf(self):
outputs,
[
{
- "translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+ "translation_text": (
+ "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide"
+ " Beide Beide"
+ )
}
],
)
@@ -87,7 +93,10 @@ def test_en_to_de_pt(self):
outputs,
[
{
- "translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
+ "translation_text": (
+ "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine"
+ " urine urine urine urine urine urine urine"
+ )
}
],
)
@@ -100,7 +109,10 @@ def test_en_to_de_tf(self):
outputs,
[
{
- "translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine"
+ "translation_text": (
+ "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine"
+ " urine urine urine urine urine urine urine"
+ )
}
],
)
diff --git a/tests/pipelines/test_pipelines_visual_question_answering.py b/tests/pipelines/test_pipelines_visual_question_answering.py
new file mode 100644
index 00000000000000..d3315681f47ebb
--- /dev/null
+++ b/tests/pipelines/test_pipelines_visual_question_answering.py
@@ -0,0 +1,115 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_available
+from transformers.pipelines import pipeline
+from transformers.testing_utils import (
+ is_pipeline_test,
+ nested_simplify,
+ require_tf,
+ require_torch,
+ require_vision,
+ slow,
+)
+
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
+
+
+if is_vision_available():
+ from PIL import Image
+else:
+
+ class Image:
+ @staticmethod
+ def open(*args, **kwargs):
+ pass
+
+
+@is_pipeline_test
+@require_torch
+@require_vision
+class VisualQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+ model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
+
+ def get_test_pipeline(self, model, tokenizer, feature_extractor):
+ vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
+ examples = [
+ {
+ "image": Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
+ "question": "How many cats are there?",
+ },
+ {
+ "image": "./tests/fixtures/tests_samples/COCO/000000039769.png",
+ "question": "How many cats are there?",
+ },
+ ]
+ return vqa_pipeline, examples
+
+ def run_pipeline_test(self, vqa_pipeline, examples):
+ outputs = vqa_pipeline(examples, top_k=1)
+ self.assertEqual(
+ outputs,
+ [
+ [{"score": ANY(float), "answer": ANY(str)}],
+ [{"score": ANY(float), "answer": ANY(str)}],
+ ],
+ )
+
+ @require_torch
+ def test_small_model_pt(self):
+ vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
+ image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
+ question = "How many cats are there?"
+
+ outputs = vqa_pipeline(image=image, question="How many cats are there?", top_k=2)
+ self.assertEqual(
+ outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
+ )
+
+ outputs = vqa_pipeline({"image": image, "question": question}, top_k=2)
+ self.assertEqual(
+ outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
+ )
+
+ @slow
+ @require_torch
+ def test_large_model_pt(self):
+ vqa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
+ image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
+ question = "How many cats are there?"
+
+ outputs = vqa_pipeline(image=image, question=question, top_k=2)
+ self.assertEqual(
+ nested_simplify(outputs, decimals=4), [{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]
+ )
+
+ outputs = vqa_pipeline({"image": image, "question": question}, top_k=2)
+ self.assertEqual(
+ nested_simplify(outputs, decimals=4), [{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]
+ )
+
+ outputs = vqa_pipeline(
+ [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2
+ )
+ self.assertEqual(
+ nested_simplify(outputs, decimals=4),
+ [[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2,
+ )
+
+ @require_tf
+ @unittest.skip("Visual question answering not implemented in TF")
+ def test_small_model_tf(self):
+ pass
diff --git a/tests/pipelines/test_pipelines_zero_shot.py b/tests/pipelines/test_pipelines_zero_shot.py
index ed564581e5260e..af98ac02017205 100644
--- a/tests/pipelines/test_pipelines_zero_shot.py
+++ b/tests/pipelines/test_pipelines_zero_shot.py
@@ -202,14 +202,39 @@ def test_large_model_pt(self):
},
)
outputs = zero_shot_classifier(
- "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
+ "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks"
+ " in an encoder-decoder configuration. The best performing models also connect the encoder and decoder"
+ " through an attention mechanism. We propose a new simple network architecture, the Transformer, based"
+ " solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two"
+ " machine translation tasks show these models to be superior in quality while being more parallelizable"
+ " and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014"
+ " English-to-German translation task, improving over the existing best results, including ensembles by"
+ " over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new"
+ " single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small"
+ " fraction of the training costs of the best models from the literature. We show that the Transformer"
+ " generalizes well to other tasks by applying it successfully to English constituency parsing both with"
+ " large and limited training data.",
candidate_labels=["machine learning", "statistics", "translation", "vision"],
multi_label=True,
)
self.assertEqual(
nested_simplify(outputs),
{
- "sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
+ "sequence": (
+ "The dominant sequence transduction models are based on complex recurrent or convolutional neural"
+ " networks in an encoder-decoder configuration. The best performing models also connect the"
+ " encoder and decoder through an attention mechanism. We propose a new simple network"
+ " architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence"
+ " and convolutions entirely. Experiments on two machine translation tasks show these models to be"
+ " superior in quality while being more parallelizable and requiring significantly less time to"
+ " train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task,"
+ " improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014"
+ " English-to-French translation task, our model establishes a new single-model state-of-the-art"
+ " BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training"
+ " costs of the best models from the literature. We show that the Transformer generalizes well to"
+ " other tasks by applying it successfully to English constituency parsing both with large and"
+ " limited training data."
+ ),
"labels": ["translation", "machine learning", "vision", "statistics"],
"scores": [0.817, 0.713, 0.018, 0.018],
},
@@ -232,14 +257,39 @@ def test_large_model_tf(self):
},
)
outputs = zero_shot_classifier(
- "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
+ "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks"
+ " in an encoder-decoder configuration. The best performing models also connect the encoder and decoder"
+ " through an attention mechanism. We propose a new simple network architecture, the Transformer, based"
+ " solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two"
+ " machine translation tasks show these models to be superior in quality while being more parallelizable"
+ " and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014"
+ " English-to-German translation task, improving over the existing best results, including ensembles by"
+ " over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new"
+ " single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small"
+ " fraction of the training costs of the best models from the literature. We show that the Transformer"
+ " generalizes well to other tasks by applying it successfully to English constituency parsing both with"
+ " large and limited training data.",
candidate_labels=["machine learning", "statistics", "translation", "vision"],
multi_label=True,
)
self.assertEqual(
nested_simplify(outputs),
{
- "sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
+ "sequence": (
+ "The dominant sequence transduction models are based on complex recurrent or convolutional neural"
+ " networks in an encoder-decoder configuration. The best performing models also connect the"
+ " encoder and decoder through an attention mechanism. We propose a new simple network"
+ " architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence"
+ " and convolutions entirely. Experiments on two machine translation tasks show these models to be"
+ " superior in quality while being more parallelizable and requiring significantly less time to"
+ " train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task,"
+ " improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014"
+ " English-to-French translation task, our model establishes a new single-model state-of-the-art"
+ " BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training"
+ " costs of the best models from the literature. We show that the Transformer generalizes well to"
+ " other tasks by applying it successfully to English constituency parsing both with large and"
+ " limited training data."
+ ),
"labels": ["translation", "machine learning", "vision", "statistics"],
"scores": [0.817, 0.713, 0.018, 0.018],
},
diff --git a/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py b/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
index 6bec48fda7adcc..534b1656d10f3e 100644
--- a/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
+++ b/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
@@ -81,8 +81,10 @@ class DataTrainingArguments:
max_seq_length: int = field(
default=128,
metadata={
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
},
)
overwrite_cache: bool = field(
@@ -91,29 +93,37 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=True,
metadata={
- "help": "Whether to pad all samples to `max_seq_length`. "
- "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ "help": (
+ "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ )
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ )
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
- "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
- "value if set."
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of test examples to this "
+ "value if set."
+ )
},
)
train_file: Optional[str] = field(
@@ -170,8 +180,10 @@ class ModelArguments:
use_auth_token: bool = field(
default=False,
metadata={
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
- "with private models)."
+ "help": (
+ "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ )
},
)
diff --git a/tests/splinter/test_modeling_splinter.py b/tests/splinter/test_modeling_splinter.py
deleted file mode 100644
index 1b6cfeac95378c..00000000000000
--- a/tests/splinter/test_modeling_splinter.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Testing suite for the PyTorch Splinter model. """
-
-
-import unittest
-
-from transformers import is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
-
-from ..test_configuration_common import ConfigTester
-from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
-
-
-if is_torch_available():
- import torch
-
- from transformers import SplinterConfig, SplinterForQuestionAnswering, SplinterModel
- from transformers.models.splinter.modeling_splinter import SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST
-
-
-class SplinterModelTester:
- def __init__(
- self,
- parent,
- batch_size=13,
- seq_length=7,
- is_training=True,
- use_input_mask=True,
- use_token_type_ids=True,
- use_labels=True,
- vocab_size=99,
- hidden_size=32,
- num_hidden_layers=5,
- num_attention_heads=4,
- intermediate_size=37,
- hidden_act="gelu",
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=512,
- type_vocab_size=16,
- type_sequence_label_size=2,
- initializer_range=0.02,
- num_labels=3,
- num_choices=4,
- scope=None,
- ):
- self.parent = parent
- self.batch_size = batch_size
- self.seq_length = seq_length
- self.is_training = is_training
- self.use_input_mask = use_input_mask
- self.use_token_type_ids = use_token_type_ids
- self.use_labels = use_labels
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.intermediate_size = intermediate_size
- self.hidden_act = hidden_act
- self.hidden_dropout_prob = hidden_dropout_prob
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
- self.max_position_embeddings = max_position_embeddings
- self.type_vocab_size = type_vocab_size
- self.type_sequence_label_size = type_sequence_label_size
- self.initializer_range = initializer_range
- self.num_labels = num_labels
- self.num_choices = num_choices
- self.scope = scope
-
- def prepare_config_and_inputs(self):
- input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
- input_mask = None
- if self.use_input_mask:
- input_mask = random_attention_mask([self.batch_size, self.seq_length])
-
- token_type_ids = None
- if self.use_token_type_ids:
- token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
- sequence_labels = None
- token_labels = None
- choice_labels = None
- if self.use_labels:
- sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
- token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
- choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
- config = SplinterConfig(
- vocab_size=self.vocab_size,
- hidden_size=self.hidden_size,
- num_hidden_layers=self.num_hidden_layers,
- num_attention_heads=self.num_attention_heads,
- intermediate_size=self.intermediate_size,
- hidden_act=self.hidden_act,
- hidden_dropout_prob=self.hidden_dropout_prob,
- attention_probs_dropout_prob=self.attention_probs_dropout_prob,
- max_position_embeddings=self.max_position_embeddings,
- type_vocab_size=self.type_vocab_size,
- is_decoder=False,
- initializer_range=self.initializer_range,
- )
-
- return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
- def create_and_check_model(
- self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
- ):
- model = SplinterModel(config=config)
- model.to(torch_device)
- model.eval()
- result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
- result = model(input_ids, token_type_ids=token_type_ids)
- result = model(input_ids)
- self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
- def create_and_check_for_question_answering(
- self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
- ):
- model = SplinterForQuestionAnswering(config=config)
- model.to(torch_device)
- model.eval()
- result = model(
- input_ids,
- attention_mask=input_mask,
- token_type_ids=token_type_ids,
- start_positions=sequence_labels,
- end_positions=sequence_labels,
- )
- self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
- self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
-
- def prepare_config_and_inputs_for_common(self):
- config_and_inputs = self.prepare_config_and_inputs()
- (
- config,
- input_ids,
- token_type_ids,
- input_mask,
- sequence_labels,
- token_labels,
- choice_labels,
- ) = config_and_inputs
- inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
- return config, inputs_dict
-
-
-@require_torch
-class SplinterModelTest(ModelTesterMixin, unittest.TestCase):
-
- all_model_classes = (
- (
- SplinterModel,
- SplinterForQuestionAnswering,
- )
- if is_torch_available()
- else ()
- )
-
- def setUp(self):
- self.model_tester = SplinterModelTester(self)
- self.config_tester = ConfigTester(self, config_class=SplinterConfig, hidden_size=37)
-
- def test_config(self):
- self.config_tester.run_common_tests()
-
- def test_model(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_model(*config_and_inputs)
-
- def test_model_various_embeddings(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- for type in ["absolute", "relative_key", "relative_key_query"]:
- config_and_inputs[0].position_embedding_type = type
- self.model_tester.create_and_check_model(*config_and_inputs)
-
- def test_for_question_answering(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
-
- @slow
- def test_model_from_pretrained(self):
- for model_name in SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
- model = SplinterModel.from_pretrained(model_name)
- self.assertIsNotNone(model)
-
-
-@require_torch
-class SplinterModelIntegrationTest(unittest.TestCase):
- @slow
- def test_splinter_question_answering(self):
- model = SplinterForQuestionAnswering.from_pretrained("tau/splinter-base-qass")
-
- # Input: "[CLS] Brad was born in [QUESTION] . He returned to the United Kingdom later . [SEP]"
- # Output should be the span "the United Kingdom"
- input_ids = torch.tensor(
- [[101, 7796, 1108, 1255, 1107, 104, 119, 1124, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
- )
- output = model(input_ids)
-
- expected_shape = torch.Size((1, 16))
- self.assertEqual(output.start_logits.shape, expected_shape)
- self.assertEqual(output.end_logits.shape, expected_shape)
-
- self.assertEqual(torch.argmax(output.start_logits), 10)
- self.assertEqual(torch.argmax(output.end_logits), 12)
diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py
index d17ff540679230..93723d11ba2b43 100644
--- a/tests/test_configuration_common.py
+++ b/tests/test_configuration_common.py
@@ -300,8 +300,9 @@ def test_config_common_kwargs_is_complete(self):
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
raise ValueError(
- "The following keys are set with the default values in `test_configuration_common.config_common_kwargs` "
- f"pick another value for them: {', '.join(keys_with_defaults)}."
+ "The following keys are set with the default values in"
+ " `test_configuration_common.config_common_kwargs` pick another value for them:"
+ f" {', '.join(keys_with_defaults)}."
)
def test_cached_files_are_used_when_internet_is_down(self):
@@ -356,7 +357,7 @@ def test_repo_versioning_before(self):
)
self.assertEqual(new_configuration.hidden_size, 2)
# This checks `_configuration_file` ia not kept in the kwargs by mistake.
- self.assertDictEqual(kwargs, {"_from_auto": True})
+ self.assertDictEqual(kwargs, {})
# Testing an older version by monkey-patching the version in the module it's used.
import transformers as old_transformers
diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py
index 3f7abcaa70c20a..4de2cb3b8bcb11 100644
--- a/tests/test_feature_extraction_common.py
+++ b/tests/test_feature_extraction_common.py
@@ -25,7 +25,7 @@
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
-from transformers.testing_utils import PASS, USER, is_staging_test
+from transformers.testing_utils import PASS, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
from transformers.utils import is_torch_available, is_vision_available
@@ -42,7 +42,7 @@
from PIL import Image
-SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
+SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
@@ -107,7 +107,8 @@ def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
- feat_extract_first.save_pretrained(tmpdirname)
+ saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
+ check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict())
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index ac45a1c10822c4..7bdb4a0590edc1 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -19,6 +19,7 @@
import json
import os
import os.path
+import pickle
import random
import sys
import tempfile
@@ -50,7 +51,9 @@
is_pt_flax_cross_test,
is_pt_tf_cross_test,
is_staging_test,
+ require_accelerate,
require_torch,
+ require_torch_gpu,
require_torch_multi_gpu,
require_usr_bin_time,
slow,
@@ -59,6 +62,7 @@
from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
+ is_accelerate_available,
is_flax_available,
is_tf_available,
is_torch_fx_available,
@@ -71,6 +75,10 @@
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
+if is_accelerate_available():
+ from accelerate.utils import compute_module_sizes
+
+
if is_torch_available():
import torch
from torch import nn
@@ -93,6 +101,8 @@
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
AdaptiveEmbedding,
+ AutoModelForCausalLM,
+ AutoTokenizer,
BertConfig,
BertModel,
PreTrainedModel,
@@ -124,6 +134,7 @@ def _config_zero_init(config):
TINY_T5 = "patrickvonplaten/t5-tiny-random"
+TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
@require_torch
@@ -143,6 +154,7 @@ class ModelTesterMixin:
test_model_parallel = False
is_encoder_decoder = False
has_attentions = True
+ model_split_percents = [0.5, 0.7, 0.9]
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict)
@@ -474,123 +486,119 @@ def test_training_gradient_checkpointing(self):
loss.backward()
def test_attention_outputs(self):
- if not self.has_attentions:
- pass
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
- else:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+ chunk_length = getattr(self.model_tester, "chunk_length", None)
+ if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
+ encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
- seq_len = getattr(self.model_tester, "seq_length", None)
- decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
- encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
- decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
- encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
- chunk_length = getattr(self.model_tester, "chunk_length", None)
- if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
- encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
- self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-
- if chunk_length is not None:
- self.assertListEqual(
- list(attentions[0].shape[-4:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
- )
- else:
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
- out_len = len(outputs)
-
- if self.is_encoder_decoder:
- correct_outlen = 5
-
- # loss is at first position
- if "labels" in inputs_dict:
- correct_outlen += 1 # loss is added to beginning
- # Question Answering model returns start_logits and end_logits
- if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
- correct_outlen += 1 # start_logits and end_logits instead of only 1 output
- if "past_key_values" in outputs:
- correct_outlen += 1 # past_key_values have been returned
-
- self.assertEqual(out_len, correct_outlen)
-
- # decoder attentions
- decoder_attentions = outputs.decoder_attentions
- self.assertIsInstance(decoder_attentions, (list, tuple))
- self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(decoder_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
- )
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
- # cross attentions
- cross_attentions = outputs.cross_attentions
- self.assertIsInstance(cross_attentions, (list, tuple))
- self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
- self.assertListEqual(
- list(cross_attentions[0].shape[-3:]),
- [
- self.model_tester.num_attention_heads,
- decoder_seq_length,
- encoder_key_length,
- ],
- )
+ if chunk_length is not None:
+ self.assertListEqual(
+ list(attentions[0].shape[-4:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
+ )
+ else:
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
+ out_len = len(outputs)
+
+ if self.is_encoder_decoder:
+ correct_outlen = 5
+
+ # loss is at first position
+ if "labels" in inputs_dict:
+ correct_outlen += 1 # loss is added to beginning
+ # Question Answering model returns start_logits and end_logits
+ if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
+ if "past_key_values" in outputs:
+ correct_outlen += 1 # past_key_values have been returned
+
+ self.assertEqual(out_len, correct_outlen)
+
+ # decoder attentions
+ decoder_attentions = outputs.decoder_attentions
+ self.assertIsInstance(decoder_attentions, (list, tuple))
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(decoder_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+ )
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ # cross attentions
+ cross_attentions = outputs.cross_attentions
+ self.assertIsInstance(cross_attentions, (list, tuple))
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(cross_attentions[0].shape[-3:]),
+ [
+ self.model_tester.num_attention_heads,
+ decoder_seq_length,
+ encoder_key_length,
+ ],
+ )
- if hasattr(self.model_tester, "num_hidden_states_types"):
- added_hidden_states = self.model_tester.num_hidden_states_types
- elif self.is_encoder_decoder:
- added_hidden_states = 2
- else:
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ elif self.is_encoder_decoder:
+ added_hidden_states = 2
+ else:
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
- self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
- if chunk_length is not None:
- self.assertListEqual(
- list(self_attentions[0].shape[-4:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
- )
- else:
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ if chunk_length is not None:
+ self.assertListEqual(
+ list(self_attentions[0].shape[-4:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
+ )
+ else:
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
@slow
def test_torchscript_simple(self):
@@ -728,18 +736,36 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels = inputs.get("labels", None)
- input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+ input_names = [
+ "attention_mask",
+ "decoder_attention_mask",
+ "decoder_input_ids",
+ "input_features",
+ "input_ids",
+ "input_values",
+ ]
if labels is not None:
input_names.append("labels")
+
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
- input_names = ["input_ids", "attention_mask", "token_type_ids"]
- input_ids = inputs["input_ids"]
+ input_names = [
+ "attention_mask",
+ "bbox",
+ "input_features",
+ "input_ids",
+ "input_values",
+ "pixel_values",
+ "token_type_ids",
+ "visual_feats",
+ "visual_pos",
+ ]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
@@ -752,21 +778,22 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
- input_names = filtered_inputs.keys()
+ input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
- rank = len(input_ids.shape)
- if rank not in [2, 3]:
- raise NotImplementedError(
- f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
- )
+ if (
+ isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
+ and not hasattr(model.config, "problem_type")
+ or model.config.problem_type is None
+ ):
+ model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
- except RuntimeError:
- self.fail("Couldn't trace module.")
+ except Exception as e:
+ self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
@@ -789,6 +816,40 @@ def flatten_output(output):
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
+ # Test that the model can be TorchScripted
+ try:
+ scripted = torch.jit.script(traced_model)
+ except Exception as e:
+ self.fail(f"Could not TorchScript the traced model: {e}")
+ scripted_output = scripted(**filtered_inputs)
+ scripted_output = flatten_output(scripted_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], scripted_output[i]),
+ f"scripted {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
def test_headmasking(self):
if not self.test_head_masking:
return
@@ -1447,7 +1508,12 @@ def recursive_check(tuple_object, dict_object):
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
- msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+ f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+ f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+ ),
)
recursive_check(tuple_output, dict_output)
@@ -1636,7 +1702,8 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
else:
raise ValueError(
- f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
+ "`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
+ f" {type(tf_outputs)} instead."
)
def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
@@ -2066,7 +2133,7 @@ def get_current_gpu_memory_use():
memory_after_parallelization = get_current_gpu_memory_use()
# Assert that the memory use on all devices is higher than it was when loaded only on CPU
- for n in range(torch.cuda.device_count()):
+ for n in range(len(model.device_map.keys())):
self.assertGreater(memory_after_parallelization[n], memory_at_start[n])
# Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it
@@ -2142,6 +2209,115 @@ def cast_to_device(dictionary, device):
model.parallelize()
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
+ def check_device_map_is_respected(self, model, device_map):
+ for param_name, param in model.named_parameters():
+ # Find device in device_map
+ while len(param_name) > 0 and param_name not in device_map:
+ param_name = ".".join(param_name.split(".")[:-1])
+ if param_name not in device_map:
+ raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
+
+ param_device = device_map[param_name]
+ if param_device in ["cpu", "disk"]:
+ self.assertEqual(param.device, torch.device("meta"))
+ else:
+ self.assertEqual(param.device, torch.device(param_device))
+
+ @require_accelerate
+ @require_torch_gpu
+ def test_disk_offload(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if model_class._no_split_modules is None:
+ continue
+
+ inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config).eval()
+ model = model.to(torch_device)
+ base_output = model(**inputs_dict)
+
+ model_size = compute_module_sizes(model)[""]
+ max_size = int(self.model_split_percents[0] * model_size)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir)
+
+ max_memory = {0: max_size, "cpu": max_size}
+ with self.assertRaises(ValueError):
+ # This errors out cause it's missing an offload folder
+ new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+ new_model = model_class.from_pretrained(
+ tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+ )
+
+ self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+ new_output = new_model(**inputs_dict)
+
+ self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
+ @require_accelerate
+ @require_torch_gpu
+ def test_cpu_offload(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if model_class._no_split_modules is None:
+ continue
+
+ inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config).eval()
+ model = model.to(torch_device)
+ base_output = model(**inputs_dict)
+
+ model_size = compute_module_sizes(model)[""]
+ # We test several splits of sizes to make sure it works.
+ max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir)
+
+ for max_size in max_gpu_sizes:
+ max_memory = {0: max_size, "cpu": model_size * 2}
+ new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+ # Making sure part of the model will actually end up offloaded
+ self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
+
+ self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+ new_output = new_model(**inputs_dict)
+
+ self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
+ @require_accelerate
+ @require_torch_multi_gpu
+ def test_model_parallelism(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if model_class._no_split_modules is None:
+ continue
+
+ inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+ model = model_class(config).eval()
+ model = model.to(torch_device)
+ base_output = model(**inputs_dict)
+
+ model_size = compute_module_sizes(model)[""]
+ # We test several splits of sizes to make sure it works.
+ max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir)
+
+ for max_size in max_gpu_sizes:
+ max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
+ new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+ # Making sure part of the model will actually end up offloaded
+ self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+
+ self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+ new_output = new_model(**inputs_dict)
+
+ self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
def test_problem_types(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -2382,6 +2558,10 @@ def test_model_from_pretrained_torch_dtype(self):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
+ # test model whose first param is not of a floating type, but int
+ model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
+ self.assertEqual(model.dtype, torch.float32)
+
def test_no_super_init_config_and_model(self):
config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)
@@ -2511,6 +2691,7 @@ def test_checkpoint_sharding_from_hub(self):
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
+ @require_accelerate
def test_from_pretrained_low_cpu_mem_usage_functional(self):
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
# sharded models
@@ -2523,6 +2704,7 @@ def test_from_pretrained_low_cpu_mem_usage_functional(self):
_ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
@require_usr_bin_time
+ @require_accelerate
def test_from_pretrained_low_cpu_mem_usage_measured(self):
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
@@ -2561,6 +2743,23 @@ def test_from_pretrained_low_cpu_mem_usage_measured(self):
# functionality to load models directly on gpu, this test can be rewritten to use torch's
# cuda memory tracking and then we should be able to do a much more precise test.
+ @require_accelerate
+ @require_torch_multi_gpu
+ @slow
+ def test_model_parallelism_gpt2(self):
+ device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
+ for i in range(12):
+ device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
+
+ model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
+
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ inputs = tokenizer("Hello, my name is", return_tensors="pt")
+ output = model.generate(inputs["input_ids"].to(0))
+
+ text_output = tokenizer.decode(output[0].tolist())
+ self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
+
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 0d38713e08d399..908d0722207378 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -25,6 +25,8 @@
from importlib import import_module
from typing import List, Tuple
+from datasets import Dataset
+
from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
@@ -505,7 +507,8 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
else:
raise ValueError(
- f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
+ "`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
+ f" {type(tf_outputs)} instead."
)
def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
@@ -956,7 +959,10 @@ def recursive_check(tuple_object, dict_object):
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
- msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
+ msg=(
+ "Tuple and dict output are not equal. Difference:"
+ f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
+ ),
)
recursive_check(tuple_output, dict_output)
@@ -972,9 +978,10 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
- tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
- dict_inputs = self._prepare_for_class(inputs_dict, model_class)
- check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+ if self.has_attentions:
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
# Not all models accept "labels" in the forward pass (yet :) )
if "labels" in inspect.signature(model.call).parameters.keys():
@@ -986,15 +993,16 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
- tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
- dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
- check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+ if self.has_attentions:
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
- tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
- dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
- check_equivalence(
- model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
- )
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ check_equivalence(
+ model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
+ )
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -1351,7 +1359,25 @@ def test_keras_fit(self):
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
- model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
+ accuracy_classes = [
+ "ForPreTraining",
+ "ForCausalLM",
+ "ForMaskedLM",
+ "ForQuestionAnswering",
+ "ForMultipleChoice",
+ "ForSequenceClassification",
+ "ForTokenClassification",
+ "ForNextSentencePrediction",
+ "LMHeadModel",
+ ]
+ for accuracy_class in accuracy_classes:
+ if model.__class__.__name__.endswith(accuracy_class):
+ metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
+ break
+ else:
+ metrics = []
+
+ model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
@@ -1361,6 +1387,7 @@ def test_keras_fit(self):
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
+ accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
history2 = model.fit(
inputs_minus_labels,
labels,
@@ -1370,7 +1397,52 @@ def test_keras_fit(self):
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
+ accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+ self.assertEqual(history1.history.keys(), history2.history.keys())
+ for key in history1.history.keys():
+ if not key.startswith("val_"):
+ self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
+ if metrics:
+ self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
+
+ # Make sure fit works with tf.data.Dataset and results are consistent
+ dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
+ # Pass in all samples as a batch to match other `fit` calls
+ dataset = dataset.batch(len(dataset))
+ history3 = model.fit(
+ dataset,
+ validation_data=dataset,
+ steps_per_epoch=1,
+ validation_steps=1,
+ shuffle=False,
+ )
+ val_loss3 = history3.history["val_loss"][0]
+ accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
+ self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
+ self.assertEqual(history1.history.keys(), history3.history.keys())
+ if metrics:
+ self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
+
+ def test_int64_inputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ prepared_for_class = self._prepare_for_class(
+ inputs_dict.copy(),
+ model_class,
+ return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False,
+ )
+ if not any(
+ [tensor.dtype.is_integer for tensor in prepared_for_class.values() if isinstance(tensor, tf.Tensor)]
+ ):
+ return # No integer inputs means no need for this test
+
+ prepared_for_class = {
+ key: tf.cast(tensor, tf.int64) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor
+ for key, tensor in prepared_for_class.items()
+ }
+ model = model_class(config)
+ model(**prepared_for_class) # No assertion, we're just checking this doesn't throw an error
def test_generate_with_headmasking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
@@ -1459,6 +1531,56 @@ def test_model_main_input_name(self):
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
+ def test_dataset_conversion(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
+ tf_inputs_dict = {
+ key: val
+ for key, val in tf_inputs_dict.items()
+ if "head_mask" not in key and isinstance(val, tf.Tensor)
+ }
+ tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
+ input_dataset = Dataset.from_dict(tf_inputs_dict)
+ tf_dataset = model.prepare_tf_dataset(
+ input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
+ )
+ test_batch = next(iter(tf_dataset))
+ if isinstance(test_batch, tf.Tensor):
+ self.assertEqual(len(test_batch), len(input_dataset)) # Assert we didn't lose any data
+ else:
+ # Assert we discarded the unwanted extra column but kept everything else
+ self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
+ self.assertNotIn("extra_unwanted_column", test_batch)
+ for tensor in test_batch.values():
+ self.assertTrue(isinstance(tensor, tf.Tensor))
+ self.assertEqual(len(tensor), len(input_dataset)) # Assert we didn't lose any data
+ model(test_batch, training=False)
+
+ if "labels" in inspect.signature(model_class.call).parameters.keys():
+ tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ if "labels" not in tf_inputs_dict:
+ return # This model isn't giving us labels after all, don't try training with it
+ tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
+ tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
+ input_dataset = Dataset.from_dict(tf_inputs_dict)
+ tf_dataset = model.prepare_tf_dataset(
+ input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
+ )
+ test_batch, test_batch_labels = next(iter(tf_dataset))
+ self.assertGreater(len(test_batch_labels), 0) # Assert the labels are present
+ feature_columns = 1 if isinstance(test_batch, tf.Tensor) else len(test_batch)
+ label_columns = 1 if isinstance(test_batch_labels, tf.Tensor) else len(test_batch_labels)
+ # Assert we discarded the unwanted extra column but kept everything else
+ self.assertEqual(feature_columns + label_columns, len(input_dataset.features) - 1)
+ if isinstance(test_batch, dict):
+ self.assertNotIn("extra_unwanted_column", test_batch)
+ if isinstance(test_batch_labels, dict):
+ self.assertNotIn("extra_unwanted_column", test_batch_labels)
+ model.compile(optimizer="sgd", run_eagerly=True)
+ model.train_on_batch(test_batch, test_batch_labels)
+
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index fe16e5e1cd524e..050f9a6f5d876b 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -51,6 +51,7 @@
from transformers.testing_utils import (
PASS,
USER,
+ check_json_file_has_correct_format,
get_tests_dir,
is_pt_tf_cross_test,
is_staging_test,
@@ -1005,7 +1006,8 @@ def test_maximum_encoding_length_single_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -1016,7 +1018,8 @@ def test_maximum_encoding_length_single_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -1131,7 +1134,8 @@ def test_maximum_encoding_length_pair_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -1142,7 +1146,8 @@ def test_maximum_encoding_length_pair_input(self):
self.assertEqual(len(cm.records), 1)
self.assertTrue(
cm.records[0].message.startswith(
- "Token indices sequence length is longer than the specified maximum sequence length for this model"
+ "Token indices sequence length is longer than the specified maximum sequence length"
+ " for this model"
)
)
@@ -2401,13 +2406,15 @@ def test_prepare_seq2seq_batch(self):
# Longer text that will definitely require truncation.
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
- " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.",
+ " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for"
+ " Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons"
+ " will only worsen the violence and misery for millions of people.",
]
tgt_text = [
"Åeful ONU declarÄ cÄ nu existÄ o soluÅ£ie militarÄ Ć®n Siria",
- "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al Rusiei "
- 'pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi cÄ noi arme nu '
- "vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
+ "Secretarul General Ban Ki-moon declarÄ cÄ rÄspunsul sÄu la intensificarea sprijinului militar al"
+ ' Rusiei pentru Siria este cÄ "nu existÄ o soluÅ£ie militarÄ" la conflictul de aproape cinci ani Åi'
+ " cÄ noi arme nu vor face decĆ¢t sÄ Ć®nrÄutÄÅ£eascÄ violenÅ£ele Åi mizeria pentru milioane de oameni.",
]
try:
batch = tokenizer.prepare_seq2seq_batch(
@@ -3319,6 +3326,11 @@ def test_save_pretrained(self):
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2)
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
+ # make sure that all ".json" files are saved in the correct format
+ for file_path in tokenizer_r_files + tokenizer_p_files:
+ if os.path.exists(file_path) and file_path.endswith(".json"):
+ check_json_file_has_correct_format(file_path)
+
# Checks it save with the same files + the tokenizer.json file for the fast one
self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f)
@@ -3658,11 +3670,9 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
break
self.assertTrue(
find,
- (
- f"'{new_special_token_str}' doesn't appear in the list "
- f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
- f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}"
- ),
+ f"'{new_special_token_str}' doesn't appear in the list "
+ f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as "
+ f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}",
)
elif special_token not in special_tokens_map:
# The special token must appear identically in the list of the new tokenizer.
@@ -3725,7 +3735,8 @@ def test_tokenizer_mismatch_warning(self):
finally:
self.assertTrue(
cm.records[0].message.startswith(
- "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+ "The tokenizer class you load from this checkpoint is not the same type as the class"
+ " this function is called from."
)
)
diff --git a/tests/tokenization/test_tokenization_fast.py b/tests/tokenization/test_tokenization_fast.py
index 9e5ad178e53a79..da98d17d7722f5 100644
--- a/tests/tokenization/test_tokenization_fast.py
+++ b/tests/tokenization/test_tokenization_fast.py
@@ -39,6 +39,7 @@ def setUp(self):
self.test_rust_tokenizer = True
model_paths = ["robot-test/dummy-tokenizer-fast", "robot-test/dummy-tokenizer-wordlevel"]
+ self.bytelevel_bpe_model_name = "SaulLu/dummy-tokenizer-bytelevel-bpe"
# Inclusion of 2 tokenizers to test different types of models (Unigram and WordLevel for the moment)
self.tokenizers_list = [(PreTrainedTokenizerFast, model_path, {}) for model_path in model_paths]
@@ -99,6 +100,15 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
shutil.rmtree(self.tmpdirname)
self.tmpdirname = tmpdirname_orig
+ def test_training_new_tokenizer_with_bytelevel(self):
+ tokenizer = self.rust_tokenizer_class.from_pretrained(self.bytelevel_bpe_model_name)
+
+ toy_text_iterator = ("a" for _ in range(1000))
+ new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=toy_text_iterator, length=1000, vocab_size=50)
+
+ encoding_ids = new_tokenizer.encode("aš¤")
+ self.assertEqual(encoding_ids, [64, 172, 253, 97, 245])
+
@require_tokenizers
class TokenizerVersioningTest(unittest.TestCase):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 1d80a85f0ef5a4..fef392aefff6d3 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -15,11 +15,13 @@
import dataclasses
import gc
+import json
import math
import os
import random
import re
import subprocess
+import sys
import tempfile
import time
import unittest
@@ -48,6 +50,7 @@
get_gpu_count,
get_tests_dir,
is_staging_test,
+ require_intel_extension_for_pytorch,
require_optuna,
require_ray,
require_sentencepiece,
@@ -60,12 +63,13 @@
require_torch_non_multi_gpu,
require_torch_tf32,
require_torch_up_to_2_gpus,
+ require_torchdynamo,
require_wandb,
slow,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
-from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
+from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
from transformers.utils.hp_naming import TrialShortNamer
@@ -161,11 +165,12 @@ def __call__(self, eval_pred):
class RegressionModelConfig(PretrainedConfig):
- def __init__(self, a=0, b=0, double_output=False, **kwargs):
+ def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
self.double_output = double_output
+ self.random_torch = random_torch
self.hidden_size = 1
@@ -263,14 +268,18 @@ def __init__(self, config):
super().__init__(config)
self.a = nn.Parameter(torch.tensor(config.a).float())
self.b = nn.Parameter(torch.tensor(config.b).float())
+ self.random_torch = config.random_torch
def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
- torch_rand = torch.randn(1).squeeze()
+ if self.random_torch:
+ torch_rand = torch.randn(1).squeeze()
np_rand = np.random.rand()
rand_rand = random.random()
- y += 0.05 * torch_rand + 0.05 * torch.tensor(np_rand + rand_rand)
+ if self.random_torch:
+ y += 0.05 * torch_rand
+ y += 0.05 * torch.tensor(np_rand + rand_rand)
if labels is None:
return (y,)
@@ -376,6 +385,25 @@ def check_trainer_state_are_the_same(self, trainer_state, trainer_state1):
_ = log1.pop(key, None)
self.assertEqual(log, log1)
+ def convert_to_sharded_checkpoint(self, folder):
+ # Converts a checkpoint of a regression model to a sharded checkpoint.
+ state_dict = torch.load(os.path.join(folder, WEIGHTS_NAME))
+ os.remove(os.path.join(folder, WEIGHTS_NAME))
+ keys = list(state_dict.keys())
+
+ shard_files = [
+ WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(keys):05d}.bin") for idx in range(len(keys))
+ ]
+ index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}
+
+ save_index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+
+ for param_name, shard_file in zip(keys, shard_files):
+ torch.save({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
+
@require_torch
@require_sentencepiece
@@ -613,6 +641,29 @@ def test_number_of_steps_in_training(self):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
+ @require_torch_bf16
+ @require_intel_extension_for_pytorch
+ def test_number_of_steps_in_training_with_ipex(self):
+ for mix_bf16 in [True, False]:
+ # Regular training has n_epochs * len(train_dl) steps
+ trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
+ train_output = trainer.train()
+ self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size)
+
+ # Check passing num_train_epochs works (and a float version too):
+ trainer = get_regression_trainer(
+ learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
+ )
+ train_output = trainer.train()
+ self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size))
+
+ # If we pass a max_steps, num_train_epochs is ignored
+ trainer = get_regression_trainer(
+ learning_rate=0.1, max_steps=10, use_ipex=True, bf16=mix_bf16, no_cuda=True
+ )
+ train_output = trainer.train()
+ self.assertEqual(train_output.global_step, 10)
+
def test_logging_inf_nan_filter(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
@@ -793,6 +844,101 @@ def test_evaluate(self):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+ def test_evaluate_with_jit(self):
+ trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True)
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ # With a number of elements not a round multiple of the batch size
+ trainer = get_regression_trainer(
+ a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy(), jit_mode_eval=True
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ # With logits preprocess
+ trainer = get_regression_trainer(
+ a=1.5,
+ b=2.5,
+ compute_metrics=AlmostAccuracy(),
+ preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
+ jit_mode_eval=True,
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ @require_torch_bf16
+ @require_intel_extension_for_pytorch
+ def test_evaluate_with_ipex(self):
+ for mix_bf16 in [True, False]:
+ trainer = get_regression_trainer(
+ a=1.5, b=2.5, use_ipex=True, compute_metrics=AlmostAccuracy(), bf16=mix_bf16, no_cuda=True
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ # With a number of elements not a round multiple of the batch size
+ trainer = get_regression_trainer(
+ a=1.5,
+ b=2.5,
+ use_ipex=True,
+ eval_len=66,
+ compute_metrics=AlmostAccuracy(),
+ bf16=mix_bf16,
+ no_cuda=True,
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
+ # With logits preprocess
+ trainer = get_regression_trainer(
+ a=1.5,
+ b=2.5,
+ use_ipex=True,
+ compute_metrics=AlmostAccuracy(),
+ preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
+ bf16=mix_bf16,
+ no_cuda=True,
+ )
+ results = trainer.evaluate()
+
+ x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
+ pred = 1.5 * x + 2.5
+ expected_loss = ((pred - y) ** 2).mean()
+ self.assertAlmostEqual(results["eval_loss"], expected_loss)
+ expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
+ self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
+
def test_predict(self):
trainer = get_regression_trainer(a=1.5, b=2.5)
preds = trainer.predict(trainer.eval_dataset).predictions
@@ -825,6 +971,85 @@ def test_predict(self):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+ def test_predict_with_jit(self):
+ trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+ # With a number of elements not a round multiple of the batch size
+ trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, jit_mode_eval=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+ # With more than one output of the model
+ trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, jit_mode_eval=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertEqual(len(preds), 2)
+ self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+ self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+
+ # With more than one output/label of the model
+ trainer = get_regression_trainer(
+ a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"], jit_mode_eval=True
+ )
+ outputs = trainer.predict(trainer.eval_dataset)
+ preds = outputs.predictions
+ labels = outputs.label_ids
+ x = trainer.eval_dataset.x
+ self.assertEqual(len(preds), 2)
+ self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+ self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+ self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
+ self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+
+ @require_torch_bf16
+ @require_intel_extension_for_pytorch
+ def test_predict_with_ipex(self):
+ for mix_bf16 in [True, False]:
+ trainer = get_regression_trainer(a=1.5, b=2.5, use_ipex=True, bf16=mix_bf16, no_cuda=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+ # With a number of elements not a round multiple of the batch size
+ trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, use_ipex=True, bf16=mix_bf16, no_cuda=True)
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
+
+ # With more than one output of the model
+ trainer = get_regression_trainer(
+ a=1.5, b=2.5, double_output=True, use_ipex=True, bf16=mix_bf16, no_cuda=True
+ )
+ preds = trainer.predict(trainer.eval_dataset).predictions
+ x = trainer.eval_dataset.x
+ self.assertEqual(len(preds), 2)
+ self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+ self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+
+ # With more than one output/label of the model
+ trainer = get_regression_trainer(
+ a=1.5,
+ b=2.5,
+ double_output=True,
+ label_names=["labels", "labels_2"],
+ use_ipex=True,
+ bf16=mix_bf16,
+ no_cuda=True,
+ )
+ outputs = trainer.predict(trainer.eval_dataset)
+ preds = outputs.predictions
+ labels = outputs.label_ids
+ x = trainer.eval_dataset.x
+ self.assertEqual(len(preds), 2)
+ self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
+ self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
+ self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
+ self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
+
def test_dynamic_shapes(self):
eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
model = RegressionModel(a=2, b=1)
@@ -996,33 +1221,95 @@ def test_can_resume_training(self):
trainer.train(resume_from_checkpoint=True)
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
- @require_torch_non_multi_gpu
def test_resume_training_with_randomness(self):
- # This test will fail flakily for more than 1 GPUs since the result will be slightly more different
- # TODO: investigate why it fails for 2 GPUs?
+ # For more than 1 GPUs, since the randomness is introduced in the model and with DataParallel (which is used
+ # in this test for more than 2 GPUs), the calls to the torch RNG will happen in a random order (sometimes
+ # GPU 0 will call first and sometimes GPU 1).
+ random_torch = not torch.cuda.is_available() or torch.cuda.device_count() <= 1
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
train_dataset = RegressionDataset(length=128)
eval_dataset = RegressionDataset()
- config = RegressionModelConfig(a=0, b=2)
- model = RegressionRandomPreTrainedModel(config)
+ with self.subTest("Test every step"):
+ config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
+ model = RegressionRandomPreTrainedModel(config)
- tmp_dir = self.get_auto_remove_tmp_dir()
- args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
- trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+ tmp_dir = self.get_auto_remove_tmp_dir()
+ args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
+ trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
- trainer.train()
- (a, b) = trainer.model.a.item(), trainer.model.b.item()
+ trainer.train()
+ (a, b) = trainer.model.a.item(), trainer.model.b.item()
- model = RegressionRandomPreTrainedModel(config)
- trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
- trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
- (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+ model = RegressionRandomPreTrainedModel(config)
+ trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+ trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
+ (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+
+ self.assertAlmostEqual(a, a1, delta=1e-8)
+ self.assertAlmostEqual(b, b1, delta=1e-8)
+
+ with self.subTest("Test every epoch"):
+ config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
+ model = RegressionRandomPreTrainedModel(config)
+
+ tmp_dir = self.get_auto_remove_tmp_dir()
+ args = RegressionTrainingArguments(tmp_dir, save_strategy="epoch", learning_rate=0.1)
+ trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+ trainer.train()
+ (a, b) = trainer.model.a.item(), trainer.model.b.item()
+
+ model = RegressionRandomPreTrainedModel(config)
+ trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+ checkpoints = [d for d in os.listdir(tmp_dir) if d.startswith("checkpoint-")]
+ # There should be one checkpoint per epoch.
+ self.assertEqual(len(checkpoints), 3)
+ checkpoint_dir = sorted(checkpoints, key=lambda x: int(x.replace("checkpoint-", "")))[0]
+
+ trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, checkpoint_dir))
+ (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+
+ self.assertAlmostEqual(a, a1, delta=1e-8)
+ self.assertAlmostEqual(b, b1, delta=1e-8)
- self.assertAlmostEqual(a, a1, delta=1e-8)
- self.assertAlmostEqual(b, b1, delta=1e-8)
+ @slow
+ @require_torch_non_multi_gpu
+ def test_auto_batch_size_finder(self):
+
+ if torch.cuda.is_available():
+ torch.backends.cudnn.deterministic = True
+
+ SRC_DIR = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "..", "..", "examples", "pytorch", "text-classification")
+ )
+ sys.path.append(SRC_DIR)
+ import run_glue
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ testargs = f"""
+ run_glue.py
+ --model_name_or_path distilbert-base-uncased
+ --task_name mrpc
+ --do_train
+ --do_eval
+ --max_seq_len 128
+ --per_device_train_batch_size 4096
+ --learning_rate 2e-5
+ --num_train_epochs 1
+ --output_dir {tmpdir}
+ --auto_find_batch_size 0
+ """.split()
+ with self.assertRaises(RuntimeError):
+ with patch.object(sys, "argv", testargs):
+ run_glue.main()
+
+ testargs[-1] = "1"
+ with patch.object(sys, "argv", testargs):
+ run_glue.main()
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
def test_training_with_resume_from_checkpoint_false(self):
@@ -1038,6 +1325,31 @@ def test_training_with_resume_from_checkpoint_false(self):
trainer.train(resume_from_checkpoint=False)
+ @require_torch_up_to_2_gpus
+ def test_resume_training_with_shard_checkpoint(self):
+ # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
+ # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
+ # won't be the same since the training dataloader is shuffled).
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
+ trainer.train()
+ (a, b) = trainer.model.a.item(), trainer.model.b.item()
+ state = dataclasses.asdict(trainer.state)
+
+ checkpoint = os.path.join(tmpdir, "checkpoint-5")
+ self.convert_to_sharded_checkpoint(checkpoint)
+
+ # Reinitialize trainer
+ trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
+
+ trainer.train(resume_from_checkpoint=checkpoint)
+ (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+ state1 = dataclasses.asdict(trainer.state)
+ self.assertEqual(a, a1)
+ self.assertEqual(b, b1)
+ self.check_trainer_state_are_the_same(state, state1)
+
@require_torch_up_to_2_gpus
def test_resume_training_with_gradient_accumulation(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
@@ -1216,7 +1528,8 @@ def test_trainer_eval_lm(self):
def test_training_iterable_dataset(self):
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
- train_dataset = SampleIterableDataset()
+ # Adding one column not used by the model should have no impact
+ train_dataset = SampleIterableDataset(label_names=["labels", "extra"])
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
@@ -1250,7 +1563,8 @@ def test_training_finite_iterable_dataset(self):
def test_evaluation_iterable_dataset(self):
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
- eval_dataset = SampleIterableDataset()
+ # Adding one column not used by the model should have no impact
+ eval_dataset = SampleIterableDataset(label_names=["labels", "extra"])
args = RegressionTrainingArguments(output_dir="./examples")
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
@@ -1287,7 +1601,8 @@ def test_predict_iterable_dataset(self):
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With a number of elements not a round multiple of the batch size
- test_dataset = SampleIterableDataset(length=66)
+ # Adding one column not used by the model should have no impact
+ test_dataset = SampleIterableDataset(length=66, label_names=["labels", "extra"])
preds = trainer.predict(test_dataset).predictions
x = test_dataset.dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
@@ -1435,7 +1750,7 @@ def test_fp16_full_eval(self):
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001
- # 1. with mem metrics enabled
+ # 1. with fp16_full_eval disabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
metrics = trainer.evaluate()
del trainer
@@ -1456,7 +1771,7 @@ def test_fp16_full_eval(self):
# perfect world: fp32_eval == close to zero
self.assertLess(fp32_eval, 5_000)
- # 2. with mem metrics disabled
+ # 2. with fp16_full_eval enabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, fp16_full_eval=True, skip_memory_metrics=False)
metrics = trainer.evaluate()
fp16_init = metrics["init_mem_gpu_alloc_delta"]
@@ -1478,6 +1793,100 @@ def test_fp16_full_eval(self):
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
+ @require_torch_non_multi_gpu
+ @require_torchdynamo
+ def test_torchdynamo_full_eval(self):
+ # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
+ n_gpus = get_gpu_count()
+
+ bs = 8
+ eval_len = 16 * n_gpus
+ # make the params are somewhat big so that there will be enough RAM consumed to be able to
+ # measure things. We should get about 64KB for a+b in fp32
+ a = torch.ones(1000, bs) + 0.001
+ b = torch.ones(1000, bs) - 0.001
+
+ # 1. Default - without TorchDynamo
+ trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
+ metrics = trainer.evaluate()
+ original_eval_loss = metrics["eval_loss"]
+ del trainer
+
+ # 2. TorchDynamo eager
+ trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager")
+ metrics = trainer.evaluate()
+ self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+ del trainer
+
+ # 3. TorchDynamo nvfuser
+ trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
+ metrics = trainer.evaluate()
+ self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+
+ @require_torch_non_multi_gpu
+ @require_torchdynamo
+ def test_torchdynamo_memory(self):
+ # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
+ class CustomTrainer(Trainer):
+ def compute_loss(self, model, inputs, return_outputs=False):
+ x = inputs["x"]
+ output = model(x)
+ if self.args.n_gpu == 1:
+ return output.mean()
+ return output
+
+ class MyModule(torch.nn.Module):
+ """Simple module that does aggressive fusion"""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ for _ in range(20):
+ x = torch.nn.functional.relu(x)
+ return x
+
+ mod = MyModule()
+
+ # 1. Default - without TorchDynamo
+ a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
+ a.grad = None
+ trainer = CustomTrainer(model=mod)
+ # warmup
+ for _ in range(10):
+ orig_loss = trainer.training_step(mod, {"x": a})
+
+ torch.cuda.reset_peak_memory_stats()
+ orig_loss = trainer.training_step(mod, {"x": a})
+ orig_peak_mem = torch.cuda.max_memory_allocated()
+ del trainer
+
+ # Reset the peak for another measurement
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ # 2. TorchDynamo nvfuser
+ a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
+ a.grad = None
+ args = TrainingArguments(output_dir="None", torchdynamo="nvfuser")
+ trainer = CustomTrainer(model=mod, args=args)
+ # warmup
+ for _ in range(10):
+ loss = trainer.training_step(mod, {"x": a})
+
+ torch.cuda.reset_peak_memory_stats()
+ loss = trainer.training_step(mod, {"x": a})
+ peak_mem = torch.cuda.max_memory_allocated()
+ del trainer
+
+ # Functional check
+ self.assertAlmostEqual(loss, orig_loss)
+
+ # AOT Autograd recomputaion and nvfuser recomputation optimization
+ # aggressively fuses the operations and reduce the memory footprint.
+ self.assertGreater(orig_peak_mem, peak_mem * 2)
+
@require_torch_gpu
@require_torch_bf16
def test_bf16_full_eval(self):
@@ -1495,7 +1904,7 @@ def test_bf16_full_eval(self):
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001
- # 1. with mem metrics enabled
+ # 1. with bf16_full_eval disabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
metrics = trainer.evaluate()
del trainer
@@ -1516,7 +1925,7 @@ def test_bf16_full_eval(self):
# perfect world: fp32_eval == close to zero
self.assertLess(fp32_eval, 5_000)
- # 2. with mem metrics disabled
+ # 2. with bf16_full_eval enabled
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False)
metrics = trainer.evaluate()
bf16_init = metrics["init_mem_gpu_alloc_delta"]
@@ -1785,6 +2194,7 @@ def test_hyperparameter_search_ray_client(self):
self.ray_hyperparameter_search()
+@slow
@require_torch
@require_sigopt
class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py
index 7710892d8d79fc..869d19b0a1e60f 100644
--- a/tests/trainer/test_trainer_utils.py
+++ b/tests/trainer/test_trainer_utils.py
@@ -18,7 +18,9 @@
import numpy as np
-from transformers.testing_utils import require_torch
+from transformers.data.data_collator import default_data_collator
+from transformers.testing_utils import require_accelerate, require_torch
+from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size
from transformers.utils import is_torch_available
@@ -39,6 +41,8 @@
SequentialDistributedSampler,
ShardSampler,
get_parameter_names,
+ numpy_pad_and_concatenate,
+ torch_pad_and_concatenate,
)
class TstLayer(nn.Module):
@@ -420,3 +424,76 @@ def test_shard_sampler(self):
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
+
+ @require_accelerate
+ def test_executable_batch_size(self):
+ batch_sizes = []
+
+ @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=True)
+ def mock_training_loop_function(batch_size):
+ nonlocal batch_sizes
+ batch_sizes.append(batch_size)
+ if batch_size > 16:
+ raise RuntimeError("CUDA out of memory.")
+
+ mock_training_loop_function()
+ self.assertEqual(batch_sizes, [64, 32, 16])
+
+ @require_accelerate
+ def test_executable_batch_size_no_search(self):
+ batch_sizes = []
+
+ @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
+ def mock_training_loop_function(batch_size):
+ nonlocal batch_sizes
+ batch_sizes.append(batch_size)
+
+ mock_training_loop_function()
+ self.assertEqual(batch_sizes, [64])
+
+ @require_accelerate
+ def test_executable_batch_size_with_error(self):
+ @find_executable_batch_size(starting_batch_size=64, auto_find_batch_size=False)
+ def mock_training_loop_function(batch_size):
+ raise RuntimeError("CUDA out of memory.")
+
+ with self.assertRaises(RuntimeError) as cm:
+ mock_training_loop_function()
+ self.assertEqual("CUDA out of memory", cm.args[0])
+
+ def test_pad_and_concatenate_with_1d(self):
+ """Tests whether pad_and_concatenate works with scalars."""
+ array1 = 1.0
+ array2 = 2.0
+ result = numpy_pad_and_concatenate(array1, array2)
+ self.assertTrue(np.array_equal(np.array([1.0, 2.0]), result))
+
+ tensor1 = torch.tensor(1.0)
+ tensor2 = torch.tensor(2.0)
+ result = torch_pad_and_concatenate(tensor1, tensor2)
+ self.assertTrue(torch.equal(result, torch.Tensor([1.0, 2.0])))
+
+ def test_remove_columns_collator(self):
+ class MockLogger:
+ def __init__(self) -> None:
+ self.called = 0
+
+ def info(self, msg):
+ self.called += 1
+ self.last_msg = msg
+
+ data_batch = [
+ {"col1": 1, "col2": 2, "col3": 3},
+ {"col1": 1, "col2": 2, "col3": 3},
+ ]
+ logger = MockLogger()
+ remove_columns_collator = RemoveColumnsCollator(
+ default_data_collator, ["col1", "col2"], logger, "model", "training"
+ )
+
+ self.assertNotIn("col3", remove_columns_collator(data_batch))
+ # check that the logging message is printed out only once
+ remove_columns_collator(data_batch)
+ remove_columns_collator(data_batch)
+ self.assertEqual(logger.called, 1)
+ self.assertIn("col3", logger.last_msg)
diff --git a/tests/utils/test_cli.py b/tests/utils/test_cli.py
index 1e5ba4fa27c9d0..f39aa600679a44 100644
--- a/tests/utils/test_cli.py
+++ b/tests/utils/test_cli.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+import shutil
import unittest
from unittest.mock import patch
-from transformers.testing_utils import CaptureStd
+from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test
class CLITest(unittest.TestCase):
@@ -30,3 +32,16 @@ def test_cli_env(self):
self.assertIn("Python version", cs.out)
self.assertIn("Platform", cs.out)
self.assertIn("Using distributed or parallel set-up in script?", cs.out)
+
+ @is_pt_tf_cross_test
+ @patch(
+ "sys.argv", ["fakeprogrampath", "pt-to-tf", "--model-name", "hf-internal-testing/tiny-random-gptj", "--no-pr"]
+ )
+ def test_cli_pt_to_tf(self):
+ import transformers.commands.transformers_cli
+
+ shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs
+ transformers.commands.transformers_cli.main()
+
+ # The original repo has no TF weights -- if they exist, they were created by the CLI
+ self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
diff --git a/tests/utils/test_convert_slow_tokenizer.py b/tests/utils/test_convert_slow_tokenizer.py
index 087fbd5053b23c..8655ea4602e76a 100644
--- a/tests/utils/test_convert_slow_tokenizer.py
+++ b/tests/utils/test_convert_slow_tokenizer.py
@@ -13,8 +13,8 @@ class FakeOriginalTokenizer:
class ConvertSlowTokenizerTest(unittest.TestCase):
def test_spm_converter_bytefallback_warning(self):
- spm_model_file_without_bytefallback = f"{get_tests_dir()}/fixtures/test_sentencepiece.model"
- spm_model_file_with_bytefallback = f"{get_tests_dir()}/fixtures/test_sentencepiece_with_bytefallback.model"
+ spm_model_file_without_bytefallback = get_tests_dir("fixtures/test_sentencepiece.model")
+ spm_model_file_with_bytefallback = get_tests_dir("fixtures/test_sentencepiece_with_bytefallback.model")
original_tokenizer_without_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_without_bytefallback)
@@ -28,9 +28,7 @@ def test_spm_converter_bytefallback_warning(self):
_ = SpmConverter(original_tokenizer_with_bytefallback)
self.assertEqual(len(w), 1)
self.assertIn(
- (
- "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
- " which is not implemented in the fast tokenizers."
- ),
+ "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
+ " which is not implemented in the fast tokenizers.",
str(w[0].message),
)
diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py
index 75c4f19caa1dce..19adfe21dd4bf6 100644
--- a/tests/utils/test_file_utils.py
+++ b/tests/utils/test_file_utils.py
@@ -99,12 +99,20 @@ def test_file_not_found(self):
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
_ = get_from_cache(url)
- def test_model_not_found(self):
- # Invalid model file.
+ def test_model_not_found_not_authenticated(self):
+ # Invalid model id.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
- with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
+ with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
_ = get_from_cache(url)
+ @unittest.skip("No authentication when testing against prod")
+ def test_model_not_found_authenticated(self):
+ # Invalid model id.
+ url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
+ with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
+ _ = get_from_cache(url, use_auth_token="hf_sometoken")
+ # ^ TODO - if we decide to unskip this: use a real / functional token
+
def test_revision_not_found(self):
# Valid file but missing revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
diff --git a/tests/utils/test_generic.py b/tests/utils/test_generic.py
new file mode 100644
index 00000000000000..6fbdbee4036070
--- /dev/null
+++ b/tests/utils/test_generic.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.utils import flatten_dict
+
+
+class GenericTester(unittest.TestCase):
+ def test_flatten_dict(self):
+ input_dict = {
+ "task_specific_params": {
+ "summarization": {"length_penalty": 1.0, "max_length": 128, "min_length": 12, "num_beams": 4},
+ "summarization_cnn": {"length_penalty": 2.0, "max_length": 142, "min_length": 56, "num_beams": 4},
+ "summarization_xsum": {"length_penalty": 1.0, "max_length": 62, "min_length": 11, "num_beams": 6},
+ }
+ }
+ expected_dict = {
+ "task_specific_params.summarization.length_penalty": 1.0,
+ "task_specific_params.summarization.max_length": 128,
+ "task_specific_params.summarization.min_length": 12,
+ "task_specific_params.summarization.num_beams": 4,
+ "task_specific_params.summarization_cnn.length_penalty": 2.0,
+ "task_specific_params.summarization_cnn.max_length": 142,
+ "task_specific_params.summarization_cnn.min_length": 56,
+ "task_specific_params.summarization_cnn.num_beams": 4,
+ "task_specific_params.summarization_xsum.length_penalty": 1.0,
+ "task_specific_params.summarization_xsum.max_length": 62,
+ "task_specific_params.summarization_xsum.min_length": 11,
+ "task_specific_params.summarization_xsum.num_beams": 6,
+ }
+
+ self.assertEqual(flatten_dict(input_dict), expected_dict)
diff --git a/tests/utils/test_model_card.py b/tests/utils/test_model_card.py
index 1004642a92a2a6..7d0e8795e0aab9 100644
--- a/tests/utils/test_model_card.py
+++ b/tests/utils/test_model_card.py
@@ -38,7 +38,10 @@ def setUp(self):
},
"training_data": {
"Dataset": "English Wikipedia dump dated 2018-12-01",
- "Preprocessing": "Using SentencePiece vocabulary of size 52k tokens. See details on https://arxiv.org/pdf/1810.03993.pdf",
+ "Preprocessing": (
+ "Using SentencePiece vocabulary of size 52k tokens. See details on"
+ " https://arxiv.org/pdf/1810.03993.pdf"
+ ),
},
"quantitative_analyses": {"BLEU": 55.1, "ROUGE-1": 76},
}
diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py
index 8edfc8eab02d4c..abdce686835077 100644
--- a/tests/utils/test_modeling_tf_core.py
+++ b/tests/utils/test_modeling_tf_core.py
@@ -205,7 +205,7 @@ def test_saved_model_creation_extended(self):
@slow
def test_mixed_precision(self):
- tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+ tf.keras.mixed_precision.set_global_policy("mixed_float16")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -216,7 +216,7 @@ def test_mixed_precision(self):
self.assertIsNotNone(outputs)
- tf.keras.mixed_precision.experimental.set_policy("float32")
+ tf.keras.mixed_precision.set_global_policy("float32")
@slow
def test_train_pipeline_custom_model(self):
diff --git a/tests/utils/test_utils_check_copies.py b/tests/utils/test_utils_check_copies.py
index 7c81df714cb955..57cecf6653ff8e 100644
--- a/tests/utils/test_utils_check_copies.py
+++ b/tests/utils/test_utils_check_copies.py
@@ -125,9 +125,48 @@ def test_is_copy_consistent(self):
def test_convert_to_localized_md(self):
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
- md_list = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.\n1. **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning."
- localized_md_list = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the Toyota Technological Institute at Chicago) 伓éč®ŗę [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
- converted_md_list_sample = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the Toyota Technological Institute at Chicago) 伓éč®ŗę [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n1. **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (ę„čŖ HuggingFace) 伓éč®ŗę [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) ē± Victor Sanh, Lysandre Debut and Thomas Wolf ååøć The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (ę„čŖ Google Research/Stanford University) 伓éč®ŗę [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) ē± Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning ååøć\n"
+ md_list = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
+ " Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
+ " Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong"
+ " Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.\n1."
+ " **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace),"
+ " released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
+ " lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same"
+ " method has been applied to compress GPT2 into"
+ " [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
+ " [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
+ " Multilingual BERT into"
+ " [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
+ " version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)**"
+ " (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders"
+ " as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang"
+ " Luong, Quoc V. Le, Christopher D. Manning."
+ )
+ localized_md_list = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the"
+ " Toyota Technological Institute at Chicago) 伓éč®ŗę [ALBERT: A Lite BERT for Self-supervised Learning of"
+ " Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian"
+ " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
+ )
+ converted_md_list_sample = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the"
+ " Toyota Technological Institute at Chicago) 伓éč®ŗę [ALBERT: A Lite BERT for Self-supervised Learning of"
+ " Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian"
+ " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n1."
+ " **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (ę„čŖ HuggingFace) 伓éč®ŗę"
+ " [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
+ " lighter](https://arxiv.org/abs/1910.01108) ē± Victor Sanh, Lysandre Debut and Thomas Wolf ååøć The same"
+ " method has been applied to compress GPT2 into"
+ " [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
+ " [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
+ " Multilingual BERT into"
+ " [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
+ " version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (ę„čŖ"
+ " Google Research/Stanford University) 伓éč®ŗę [ELECTRA: Pre-training text encoders as discriminators rather"
+ " than generators](https://arxiv.org/abs/2003.10555) ē± Kevin Clark, Minh-Thang Luong, Quoc V. Le,"
+ " Christopher D. Manning ååøć\n"
+ )
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
md_list, localized_md_list, localized_readme["format_model_list"]
@@ -143,9 +182,24 @@ def test_convert_to_localized_md(self):
# Check whether the number of models is equal to README.md after conversion.
self.assertTrue(num_models_equal)
- link_changed_md_list = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut."
- link_unchanged_md_list = "1. **[ALBERT](https://huggingface.co/transformers/main/model_doc/albert.html)** (ę„čŖ Google Research and the Toyota Technological Institute at Chicago) 伓éč®ŗę [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
- converted_md_list_sample = "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the Toyota Technological Institute at Chicago) 伓éč®ŗę [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
+ link_changed_md_list = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
+ " Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
+ " Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong"
+ " Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut."
+ )
+ link_unchanged_md_list = (
+ "1. **[ALBERT](https://huggingface.co/transformers/main/model_doc/albert.html)** (ę„čŖ Google Research and"
+ " the Toyota Technological Institute at Chicago) 伓éč®ŗę [ALBERT: A Lite BERT for Self-supervised Learning of"
+ " Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian"
+ " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
+ )
+ converted_md_list_sample = (
+ "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (ę„čŖ Google Research and the"
+ " Toyota Technological Institute at Chicago) 伓éč®ŗę [ALBERT: A Lite BERT for Self-supervised Learning of"
+ " Language Representations](https://arxiv.org/abs/1909.11942), ē± Zhenzhong Lan, Mingda Chen, Sebastian"
+ " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut ååøć\n"
+ )
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py
new file mode 100644
index 00000000000000..382f42bfe159d5
--- /dev/null
+++ b/utils/check_config_docstrings.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+import re
+
+
+# All paths are set with the intent you should run this script from the root of the repo with the command
+# python utils/check_config_docstrings.py
+PATH_TO_TRANSFORMERS = "src/transformers"
+
+
+# This is to make sure the transformers module imported is the one in the repo.
+spec = importlib.util.spec_from_file_location(
+ "transformers",
+ os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
+ submodule_search_locations=[PATH_TO_TRANSFORMERS],
+)
+transformers = spec.loader.load_module()
+
+CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
+
+# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
+# For example, `[bert-base-uncased](https://huggingface.co/bert-base-uncased)`
+_re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)")
+
+
+CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
+ "CLIPConfig",
+ "DecisionTransformerConfig",
+ "EncoderDecoderConfig",
+ "RagConfig",
+ "SpeechEncoderDecoderConfig",
+ "VisionEncoderDecoderConfig",
+ "VisionTextDualEncoderConfig",
+}
+
+
+def check_config_docstrings_have_checkpoints():
+ configs_without_checkpoint = []
+
+ for config_class in list(CONFIG_MAPPING.values()):
+ checkpoint_found = False
+
+ # source code of `config_class`
+ config_source = inspect.getsource(config_class)
+ checkpoints = _re_checkpoint.findall(config_source)
+
+ for checkpoint in checkpoints:
+ # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link.
+ # For example, `('bert-base-uncased', 'https://huggingface.co/bert-base-uncased')`
+ ckpt_name, ckpt_link = checkpoint
+
+ # verify the checkpoint name corresponds to the checkpoint link
+ ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
+ if ckpt_link == ckpt_link_from_name:
+ checkpoint_found = True
+ break
+
+ name = config_class.__name__
+ if not checkpoint_found and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
+ configs_without_checkpoint.append(name)
+
+ if len(configs_without_checkpoint) > 0:
+ message = "\n".join(sorted(configs_without_checkpoint))
+ raise ValueError(f"The following configurations don't contain any valid checkpoint:\n{message}")
+
+
+if __name__ == "__main__":
+ check_config_docstrings_have_checkpoints()
diff --git a/utils/check_copies.py b/utils/check_copies.py
index 5363fd1ff338c7..0f0c45ead59bcd 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -15,6 +15,7 @@
import argparse
import glob
+import importlib.util
import os
import re
@@ -40,26 +41,47 @@
"README.md": {
"start_prompt": "š¤ Transformers currently provides the following architectures",
"end_prompt": "1. Want to contribute a new model?",
- "format_model_list": "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by {paper_authors}.{supplements}",
+ "format_model_list": (
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+ " {paper_authors}.{supplements}"
+ ),
},
"README_zh-hans.md": {
"start_prompt": "š¤ Transformers ē®åęÆęå¦äøēę¶ę",
"end_prompt": "1. ę³č¦č“”ē®ę°ēęØ”åļ¼",
- "format_model_list": "**[{title}]({model_link})** (ę„čŖ {paper_affiliations}) 伓éč®ŗę {paper_title_link} ē± {paper_authors} ååøć{supplements}",
+ "format_model_list": (
+ "**[{title}]({model_link})** (ę„čŖ {paper_affiliations}) 伓éč®ŗę {paper_title_link} ē± {paper_authors}"
+ " ååøć{supplements}"
+ ),
},
"README_zh-hant.md": {
"start_prompt": "š¤ Transformers ē®åęÆę“仄äøēę¶ę§",
"end_prompt": "1. ę³č¦č²¢ē»ę°ēęØ”åļ¼",
- "format_model_list": "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by {paper_authors}.{supplements}",
+ "format_model_list": (
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+ " {paper_authors}.{supplements}"
+ ),
},
"README_ko.md": {
"start_prompt": "š¤ Transformersė ė¤ģ ėŖØėøė¤ģ ģ ź³µķ©ėė¤",
"end_prompt": "1. ģė”ģ“ ėŖØėøģ ģ¬ė¦¬ź³ ģ¶ėģ?",
- "format_model_list": "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by {paper_authors}.{supplements}",
+ "format_model_list": (
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
+ " {paper_authors}.{supplements}"
+ ),
},
}
+# This is to make sure the transformers module imported is the one in the repo.
+spec = importlib.util.spec_from_file_location(
+ "transformers",
+ os.path.join(TRANSFORMERS_PATH, "__init__.py"),
+ submodule_search_locations=[TRANSFORMERS_PATH],
+)
+transformers_module = spec.loader.load_module()
+
+
def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
@@ -130,7 +152,7 @@ def blackify(code):
has_indent = len(get_indent(code)) > 0
if has_indent:
code = f"class Bla:\n{code}"
- mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
+ mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119, preview=True)
result = black.format_str(code, mode=mode)
result, _ = style_docstrings_in_code(result)
return result[len("class Bla:\n") :] if has_indent else result
@@ -300,8 +322,6 @@ def _rep(match):
# This regex is used to synchronize link.
_re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*")
- num_models_equal = True
-
if len(localized_model_list) == 0:
localized_model_index = {}
else:
@@ -313,10 +333,16 @@ def _rep(match):
except AttributeError:
raise AttributeError("A model name in localized READMEs cannot be recognized.")
+ model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")]
+
+ # We exclude keys in localized README not in the main one.
+ readmes_match = not any([k not in model_keys for k in localized_model_index])
+ localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys}
+
for model in model_list.strip().split("\n"):
title, model_link = _re_capture_title_link.search(model).groups()
if title not in localized_model_index:
- num_models_equal = False
+ readmes_match = False
# Add an anchor white space behind a model description string for regex.
# If metadata cannot be captured, the English version will be directly copied.
localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ")
@@ -328,7 +354,7 @@ def _rep(match):
sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower())
- return num_models_equal, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
+ return readmes_match, "\n".join(map(lambda x: x[1], sorted_index)) + "\n"
def convert_readme_to_index(model_list):
@@ -368,7 +394,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
readme = f.read()
new_readme = readme.replace("https://huggingface.co/transformers", "https://huggingface.co/docs/transformers")
- new_readme = readme.replace(
+ new_readme = new_readme.replace(
"https://huggingface.co/docs/main/transformers", "https://huggingface.co/docs/transformers/main"
)
if new_readme != readme:
@@ -400,9 +426,9 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
_format_model_list = value["format_model_list"]
localized_md_list = get_model_list(filename, _start_prompt, _end_prompt)
- num_models_equal, converted_md_list = convert_to_localized_md(md_list, localized_md_list, _format_model_list)
+ readmes_match, converted_md_list = convert_to_localized_md(md_list, localized_md_list, _format_model_list)
- converted_md_lists.append((filename, num_models_equal, converted_md_list, _start_prompt, _end_prompt))
+ converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt))
converted_md_list = convert_readme_to_index(md_list)
if converted_md_list != index_list:
@@ -416,7 +442,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
for converted_md_list in converted_md_lists:
- filename, num_models_equal, converted_md, _start_prompt, _end_prompt = converted_md_list
+ filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list
if filename == "README.md":
continue
@@ -426,17 +452,94 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
with open(os.path.join(REPO_PATH, filename), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [converted_md] + lines[end_index:])
- elif not num_models_equal:
+ elif not readmes_match:
raise ValueError(
f"The model list in the README changed and the list in `{filename}` has not been updated. Run "
"`make fix-copies` to fix this."
)
+SPECIAL_MODEL_NAMES = {
+ "Bert Generation": "BERT For Sequence Generation",
+ "BigBird": "BigBird-RoBERTa",
+ "Data2VecAudio": "Data2Vec",
+ "Data2VecText": "Data2Vec",
+ "Data2VecVision": "Data2Vec",
+ "Marian": "MarianMT",
+ "OpenAI GPT-2": "GPT-2",
+ "OpenAI GPT": "GPT",
+ "Perceiver": "Perceiver IO",
+ "ViT": "Vision Transformer (ViT)",
+}
+
+# Update this list with the models that shouldn't be in the README. This only concerns modular models or those who do
+# not have an associated paper.
+MODELS_NOT_IN_README = [
+ "BertJapanese",
+ "Encoder decoder",
+ "FairSeq Machine-Translation",
+ "HerBERT",
+ "RetriBERT",
+ "Speech Encoder decoder",
+ "Speech2Text",
+ "Speech2Text2",
+ "Vision Encoder decoder",
+ "VisionTextDualEncoder",
+]
+
+
+README_TEMPLATE = (
+ "1. **[{model_name}](https://huggingface.co/docs/transformers/model_doc/{model_type})** (from ) "
+ "released with the paper []() by ."
+)
+
+
+def check_readme(overwrite=False):
+ info = LOCALIZED_READMES["README.md"]
+ models, start_index, end_index, lines = _find_text_in_file(
+ os.path.join(REPO_PATH, "README.md"),
+ info["start_prompt"],
+ info["end_prompt"],
+ )
+ models_in_readme = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in models.strip().split("\n")]
+
+ model_names_mapping = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
+ absents = [
+ (key, name)
+ for key, name in model_names_mapping.items()
+ if SPECIAL_MODEL_NAMES.get(name, name) not in models_in_readme
+ ]
+ # Remove exceptions
+ absents = [(key, name) for key, name in absents if name not in MODELS_NOT_IN_README]
+ if len(absents) > 0 and not overwrite:
+ print(absents)
+ raise ValueError(
+ "The main README doesn't contain all models, run `make fix-copies` to fill it with the missing model(s)"
+ " then complete the generated entries.\nIf the model is not supposed to be in the main README, add it to"
+ " the list `MODELS_NOT_IN_README` in utils/check_copies.py.\nIf it has a different name in the repo than"
+ " in the README, map the correspondence in `SPECIAL_MODEL_NAMES` in utils/check_copies.py."
+ )
+
+ new_models = [README_TEMPLATE.format(model_name=name, model_type=key) for key, name in absents]
+
+ all_models = models.strip().split("\n") + new_models
+ all_models = sorted(all_models, key=lambda x: re.search(r"\*\*\[([^\]]*)", x).groups()[0].lower())
+ all_models = "\n".join(all_models) + "\n"
+
+ if all_models != models:
+ if overwrite:
+ print("Fixing the main README.")
+ with open(os.path.join(REPO_PATH, "README.md"), "w", encoding="utf-8", newline="\n") as f:
+ f.writelines(lines[:start_index] + [all_models] + lines[end_index:])
+ else:
+ raise ValueError("The main README model list is not properly sorted. Run `make fix-copies` to fix this.")
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
args = parser.parse_args()
+ check_readme(args.fix_and_overwrite)
check_copies(args.fix_and_overwrite)
check_full_copies(args.fix_and_overwrite)
diff --git a/utils/check_dummies.py b/utils/check_dummies.py
index c1625036c4e3fc..d6c1c4b592f86a 100644
--- a/utils/check_dummies.py
+++ b/utils/check_dummies.py
@@ -26,7 +26,7 @@
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
-_re_test_backend = re.compile(r"^\s+if\s+is\_[a-z]*\_available\(\)")
+_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
DUMMY_CONSTANT = """
@@ -73,6 +73,8 @@ def read_init():
# If the line is an if is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
if backend is not None:
+ while not lines[line_index].startswith(" else:"):
+ line_index += 1
line_index += 1
objects = []
diff --git a/utils/check_inits.py b/utils/check_inits.py
index 18353581fcffd5..98d4caf010216b 100644
--- a/utils/check_inits.py
+++ b/utils/check_inits.py
@@ -25,10 +25,12 @@
# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
+# Catches a one-line _import_struct = {xxx}
+_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
-# Catches a line if is_foo_available
-_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)")
+# Catches a line if not is_foo_available
+_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
@@ -39,6 +41,10 @@
_re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
+# Catches a line with try:
+_re_try = re.compile(r"^\s*try:")
+# Catches a line with else:
+_re_else = re.compile(r"^\s*else:")
def find_backend(line):
@@ -70,6 +76,14 @@ def parse_init(init_file):
objects = []
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
line = lines[line_index]
+ # If we have everything on a single line, let's deal with it.
+ if _re_one_line_import_struct.search(line):
+ content = _re_one_line_import_struct.search(line).groups()[0]
+ imports = re.findall("\[([^\]]+)\]", content)
+ for imp in imports:
+ objects.extend([obj[1:-1] for obj in imp.split(", ")])
+ line_index += 1
+ continue
single_line_import_search = _re_import_struct_key_value.search(line)
if single_line_import_search is not None:
imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
@@ -81,11 +95,21 @@ def parse_init(init_file):
import_dict_objects = {"none": objects}
# Let's continue with backend-specific objects in _import_structure
while not lines[line_index].startswith("if TYPE_CHECKING"):
- # If the line is an if is_backend_available, we grab all objects associated.
+ # If the line is an if not is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
+ # Check if the backend declaration is inside a try block:
+ if _re_try.search(lines[line_index - 1]) is None:
+ backend = None
+
if backend is not None:
line_index += 1
+ # Scroll until we hit the else block of try-except-else
+ while _re_else.search(lines[line_index]) is None:
+ line_index += 1
+
+ line_index += 1
+
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
@@ -130,11 +154,21 @@ def parse_init(init_file):
type_hint_objects = {"none": objects}
# Let's continue with backend-specific objects
while line_index < len(lines):
- # If the line is an if is_backemd_available, we grab all objects associated.
+ # If the line is an if is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index])
+ # Check if the backend declaration is inside a try block:
+ if _re_try.search(lines[line_index - 1]) is None:
+ backend = None
+
if backend is not None:
line_index += 1
+ # Scroll until we hit the else block of try-except-else
+ while _re_else.search(lines[line_index]) is None:
+ line_index += 1
+
+ line_index += 1
+
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
@@ -225,7 +259,7 @@ def get_transformers_submodules():
if fname == "__init__.py":
continue
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
- submodule = short_path.replace(os.path.sep, ".").replace(".py", "")
+ submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
if len(submodule.split(".")) == 1:
submodules.append(submodule)
return submodules
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 2c8ca66abb8473..c3060b048aef18 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -36,6 +36,7 @@
# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
"DPRSpanPredictor",
+ "LongT5Stack",
"RealmBertModel",
"T5Stack",
"TFDPRSpanPredictor",
@@ -45,6 +46,7 @@
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
+ "OPTDecoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
"SegformerDecodeHead", # Building part of bigger (tested) model.
"PLBartEncoder", # Building part of bigger (tested) model.
@@ -58,6 +60,7 @@
"DetrDecoderWrapper", # Building part of bigger (tested) model.
"M2M100Encoder", # Building part of bigger (tested) model.
"M2M100Decoder", # Building part of bigger (tested) model.
+ "MCTCTEncoder", # Building part of bigger (tested) model.
"Speech2TextEncoder", # Building part of bigger (tested) model.
"Speech2TextDecoder", # Building part of bigger (tested) model.
"LEDEncoder", # Building part of bigger (tested) model.
@@ -91,26 +94,28 @@
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
"SeparableConv1D", # Building part of bigger (tested) model.
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
+ "FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
+ "OPTDecoderWrapper",
]
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
# trigger the common tests.
TEST_FILES_WITH_NO_COMMON_TESTS = [
- "decision_transformer/test_modeling_decision_transformer.py",
- "camembert/test_modeling_camembert.py",
- "mt5/test_modeling_flax_mt5.py",
- "mbart/test_modeling_mbart.py",
- "mt5/test_modeling_mt5.py",
- "pegasus/test_modeling_pegasus.py",
- "camembert/test_modeling_tf_camembert.py",
- "mt5/test_modeling_tf_mt5.py",
- "xlm_roberta/test_modeling_tf_xlm_roberta.py",
- "xlm_roberta/test_modeling_flax_xlm_roberta.py",
- "xlm_prophetnet/test_modeling_xlm_prophetnet.py",
- "xlm_roberta/test_modeling_xlm_roberta.py",
- "vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
- "vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
- "decision_transformer/test_modeling_decision_transformer.py",
+ "models/decision_transformer/test_modeling_decision_transformer.py",
+ "models/camembert/test_modeling_camembert.py",
+ "models/mt5/test_modeling_flax_mt5.py",
+ "models/mbart/test_modeling_mbart.py",
+ "models/mt5/test_modeling_mt5.py",
+ "models/pegasus/test_modeling_pegasus.py",
+ "models/camembert/test_modeling_tf_camembert.py",
+ "models/mt5/test_modeling_tf_mt5.py",
+ "models/xlm_roberta/test_modeling_tf_xlm_roberta.py",
+ "models/xlm_roberta/test_modeling_flax_xlm_roberta.py",
+ "models/xlm_prophetnet/test_modeling_xlm_prophetnet.py",
+ "models/xlm_roberta/test_modeling_xlm_roberta.py",
+ "models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
+ "models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
+ "models/decision_transformer/test_modeling_decision_transformer.py",
]
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
@@ -145,6 +150,10 @@
"DetrForSegmentation",
"DPRReader",
"FlaubertForQuestionAnswering",
+ "FlavaImageCodebook",
+ "FlavaTextModel",
+ "FlavaImageModel",
+ "FlavaMultimodalModel",
"GPT2DoubleHeadsModel",
"LukeForMaskedLM",
"LukeForEntityClassification",
@@ -307,7 +316,12 @@ def check_models_are_in_init():
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
- """Get the model test files."""
+ """Get the model test files.
+
+ The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be
+ considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
+ """
+
_ignore_files = [
"test_modeling_common",
"test_modeling_encoder_decoder",
@@ -318,20 +332,23 @@ def get_model_test_files():
"test_modeling_tf_encoder_decoder",
]
test_files = []
- for file_or_dir in os.listdir(PATH_TO_TESTS):
- path = os.path.join(PATH_TO_TESTS, file_or_dir)
- if os.path.isdir(path):
- filenames = [os.path.join(file_or_dir, file) for file in os.listdir(path)]
- else:
- filenames = [file_or_dir]
-
- for filename in filenames:
- if (
- os.path.isfile(os.path.join(PATH_TO_TESTS, filename))
- and "test_modeling" in filename
- and not os.path.splitext(filename)[0] in _ignore_files
- ):
- test_files.append(filename)
+ # Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
+ model_test_root = os.path.join(PATH_TO_TESTS, "models")
+ model_test_dirs = []
+ for x in os.listdir(model_test_root):
+ x = os.path.join(model_test_root, x)
+ if os.path.isdir(x):
+ model_test_dirs.append(x)
+
+ for target_dir in [PATH_TO_TESTS] + model_test_dirs:
+ for file_or_dir in os.listdir(target_dir):
+ path = os.path.join(target_dir, file_or_dir)
+ if os.path.isfile(path):
+ filename = os.path.split(path)[-1]
+ if "test_modeling" in filename and not os.path.splitext(filename)[0] in _ignore_files:
+ file = os.path.join(*path.split(os.sep)[1:])
+ test_files.append(file)
+
return test_files
@@ -509,7 +526,8 @@ def check_all_decorator_order():
if len(errors) > 0:
msg = "\n".join(errors)
raise ValueError(
- f"The parameterized decorator (and its variants) should always be first, but this is not the case in the following files:\n{msg}"
+ "The parameterized decorator (and its variants) should always be first, but this is not the case in the"
+ f" following files:\n{msg}"
)
@@ -708,7 +726,7 @@ def check_docstrings_are_in_md():
"""Check all docstrings are in md"""
files_with_rst = []
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
- with open(file, "r") as f:
+ with open(file, encoding="utf-8") as f:
code = f.read()
docstrings = code.split('"""')
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
index 456ff4aedc94f0..375cdb662f3ab0 100644
--- a/utils/custom_init_isort.py
+++ b/utils/custom_init_isort.py
@@ -167,7 +167,7 @@ def sort_imports(file, check_only=True):
"""
Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite.
"""
- with open(file, "r") as f:
+ with open(file, encoding="utf-8") as f:
code = f.read()
if "_import_structure" not in code:
@@ -227,7 +227,7 @@ def sort_imports(file, check_only=True):
return True
else:
print(f"Overwriting {file}.")
- with open(file, "w") as f:
+ with open(file, "w", encoding="utf-8") as f:
f.write("\n".join(main_blocks))
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 12ca41b413a0a9..462c75bc5c550c 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -1,5 +1,7 @@
docs/source/en/quicktour.mdx
docs/source/es/quicktour.mdx
+docs/source/en/pipeline_tutorial.mdx
+docs/source/en/autoclass_tutorial.mdx
docs/source/en/task_summary.mdx
docs/source/en/model_doc/speech_to_text.mdx
docs/source/en/model_doc/t5.mdx
@@ -18,6 +20,8 @@ src/transformers/models/big_bird/modeling_big_bird.py
src/transformers/models/blenderbot/modeling_blenderbot.py
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
src/transformers/models/convnext/modeling_convnext.py
+src/transformers/models/ctrl/modeling_ctrl.py
+src/transformers/models/cvt/modeling_cvt.py
src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_vision.py
src/transformers/models/deit/modeling_deit.py
@@ -28,10 +32,18 @@ src/transformers/models/glpn/modeling_glpn.py
src/transformers/models/gpt2/modeling_gpt2.py
src/transformers/models/gptj/modeling_gptj.py
src/transformers/models/hubert/modeling_hubert.py
+src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
+src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
+src/transformers/models/longformer/modeling_longformer.py
+src/transformers/models/longformer/modeling_tf_longformer.py
+src/transformers/models/longt5/modeling_longt5.py
src/transformers/models/marian/modeling_marian.py
src/transformers/models/mbart/modeling_mbart.py
src/transformers/models/mobilebert/modeling_mobilebert.py
src/transformers/models/mobilebert/modeling_tf_mobilebert.py
+src/transformers/models/opt/modeling_opt.py
+src/transformers/models/opt/modeling_tf_opt.py
+src/transformers/models/opt/modeling_flax_opt.py
src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py
@@ -57,6 +69,7 @@ src/transformers/models/vit/modeling_tf_vit.py
src/transformers/models/vit_mae/modeling_vit_mae.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
+src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
-src/transformers/models/wavlm/modeling_wavlm.py
-src/transformers/models/ctrl/modeling_ctrl.py
+src/transformers/models/wavlm/modeling_wavlm.py
+src/transformers/models/yolos/modeling_yolos.py
diff --git a/utils/notification_service.py b/utils/notification_service.py
index c0f2cdb25fbba7..0c2de5baca6d18 100644
--- a/utils/notification_service.py
+++ b/utils/notification_service.py
@@ -98,8 +98,9 @@ def dicts_to_sum(objects: Union[Dict[str, Dict], List[dict]]):
class Message:
- def __init__(self, title: str, model_results: Dict, additional_results: Dict):
+ def __init__(self, title: str, ci_title: str, model_results: Dict, additional_results: Dict):
self.title = title
+ self.ci_title = ci_title
# Failures and success of the modeling tests
self.n_model_success = sum(r["success"] for r in model_results.values())
@@ -158,6 +159,10 @@ def time(self) -> str:
def header(self) -> Dict:
return {"type": "header", "text": {"type": "plain_text", "text": self.title}}
+ @property
+ def ci_title_section(self) -> Dict:
+ return {"type": "section", "text": {"type": "mrkdwn", "text": self.ci_title}}
+
@property
def no_failures(self) -> Dict:
return {
@@ -180,7 +185,10 @@ def failures(self) -> Dict:
"type": "section",
"text": {
"type": "plain_text",
- "text": f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in {self.time}.",
+ "text": (
+ f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in"
+ f" {self.time}."
+ ),
"emoji": True,
},
"accessory": {
@@ -343,6 +351,9 @@ def additional_failures(self) -> Dict:
def payload(self) -> str:
blocks = [self.header]
+ if self.ci_title:
+ blocks.append(self.ci_title_section)
+
if self.n_model_failures > 0 or self.n_additional_failures > 0:
blocks.append(self.failures)
@@ -378,7 +389,7 @@ def error_out():
print(json.dumps({"blocks": json.loads(payload)}))
client.chat_postMessage(
- channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"],
+ channel=os.environ["CI_SLACK_REPORT_CHANNEL_ID"],
text="There was an issue running the tests.",
blocks=payload,
)
@@ -390,14 +401,28 @@ def post(self):
text = f"{self.n_failures} failures out of {self.n_tests} tests," if self.n_failures else "All tests passed."
self.thread_ts = client.chat_postMessage(
- channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"],
+ channel=os.environ["CI_SLACK_REPORT_CHANNEL_ID"],
blocks=self.payload,
text=text,
)
def get_reply_blocks(self, job_name, job_result, failures, device, text):
- if len(failures) > 2500:
- failures = "\n".join(failures.split("\n")[:20]) + "\n\n[Truncated]"
+ """
+ failures: A list with elements of the form {"line": full test name, "trace": error trace}
+ """
+ # `text` must be less than 3001 characters in Slack SDK
+ # keep some room for adding "[Truncated]" when necessary
+ MAX_ERROR_TEXT = 3000 - len("[Truncated]")
+
+ failure_text = ""
+ for idx, error in enumerate(failures):
+ new_text = failure_text + f'*{error["line"]}*\n_{error["trace"]}_\n\n'
+ if len(new_text) > MAX_ERROR_TEXT:
+ # `failure_text` here has length <= 3000
+ failure_text = failure_text + "[Truncated]"
+ break
+ # `failure_text` here has length <= MAX_ERROR_TEXT
+ failure_text = new_text
title = job_name
if device is not None:
@@ -415,7 +440,7 @@ def get_reply_blocks(self, job_name, job_result, failures, device, text):
return [
{"type": "header", "text": {"type": "plain_text", "text": title.upper(), "emoji": True}},
content,
- {"type": "section", "text": {"type": "mrkdwn", "text": failures}},
+ {"type": "section", "text": {"type": "mrkdwn", "text": failure_text}},
]
def post_reply(self):
@@ -436,7 +461,7 @@ def post_reply(self):
print(json.dumps({"blocks": blocks}))
client.chat_postMessage(
- channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"],
+ channel=os.environ["CI_SLACK_REPORT_CHANNEL_ID"],
text=f"Results for {job}",
blocks=blocks,
thread_ts=self.thread_ts["ts"],
@@ -459,7 +484,7 @@ def post_reply(self):
print(json.dumps({"blocks": blocks}))
client.chat_postMessage(
- channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"],
+ channel=os.environ["CI_SLACK_REPORT_CHANNEL_ID"],
text=f"Results for {job}",
blocks=blocks,
thread_ts=self.thread_ts["ts"],
@@ -494,7 +519,7 @@ def retrieve_artifact(name: str, gpu: Optional[str]):
raise ValueError(f"Invalid GPU for artifact. Passed GPU: `{gpu}`.")
if gpu is not None:
- name = f"{gpu}-gpu-docker_{name}"
+ name = f"{gpu}-gpu_{name}"
_artifact = {}
@@ -528,8 +553,8 @@ def add_path(self, path: str, gpu: str = None):
directories = filter(os.path.isdir, os.listdir())
for directory in directories:
- if directory.startswith("single-gpu-docker"):
- artifact_name = directory[len("single-gpu-docker") + 1 :]
+ if directory.startswith("single-gpu"):
+ artifact_name = directory[len("single-gpu") + 1 :]
if artifact_name in _available_artifacts:
_available_artifacts[artifact_name].single_gpu = True
@@ -538,8 +563,8 @@ def add_path(self, path: str, gpu: str = None):
_available_artifacts[artifact_name].add_path(directory, gpu="single")
- elif directory.startswith("multi-gpu-docker"):
- artifact_name = directory[len("multi-gpu-docker") + 1 :]
+ elif directory.startswith("multi-gpu"):
+ artifact_name = directory[len("multi-gpu") + 1 :]
if artifact_name in _available_artifacts:
_available_artifacts[artifact_name].multi_gpu = True
@@ -558,9 +583,15 @@ def add_path(self, path: str, gpu: str = None):
if __name__ == "__main__":
+
+ # This env. variable is set in workflow file (under the job `send_results`).
+ ci_event = os.environ["CI_EVENT"]
+
arguments = sys.argv[1:][0]
try:
models = ast.literal_eval(arguments)
+ # Need to change from elements like `models/bert` to `models_bert` (the ones used as artifact names).
+ models = [x.replace("models/", "models_") for x in models]
except SyntaxError:
Message.error_out()
raise ValueError("Errored out.")
@@ -604,7 +635,8 @@ def add_path(self, path: str, gpu: str = None):
if "stats" in artifact:
# Link to the GitHub Action job
model_results[model]["job_link"] = github_actions_job_links.get(
- f"Model tests ({model}, {artifact_path['gpu']}-gpu-docker)"
+ # The job names use `matrix.folder` which contain things like `models/bert` instead of `models_bert`
+ f"Model tests ({model.replace('models_', 'models/')}, {artifact_path['gpu']}-gpu)"
)
failed, success, time_spent = handle_test_results(artifact["stats"])
@@ -620,16 +652,16 @@ def add_path(self, path: str, gpu: str = None):
line = line.split()[0].replace("\n", "")
if artifact_path["gpu"] not in model_results[model]["failures"]:
- model_results[model]["failures"][artifact_path["gpu"]] = ""
+ model_results[model]["failures"][artifact_path["gpu"]] = []
- model_results[model]["failures"][
- artifact_path["gpu"]
- ] += f"*{line}*\n_{stacktraces.pop(0)}_\n\n"
+ model_results[model]["failures"][artifact_path["gpu"]].append(
+ {"line": line, "trace": stacktraces.pop(0)}
+ )
- if re.search("_tf_", line):
+ if re.search("test_modeling_tf_", line):
model_results[model]["failed"]["TensorFlow"][artifact_path["gpu"]] += 1
- elif re.search("_flax_", line):
+ elif re.search("test_modeling_flax_", line):
model_results[model]["failed"]["Flax"][artifact_path["gpu"]] += 1
elif re.search("test_modeling", line):
@@ -662,6 +694,11 @@ def add_path(self, path: str, gpu: str = None):
"Torch CUDA extension tests": "run_tests_torch_cuda_extensions_gpu_test_reports",
}
+ if ci_event == "push":
+ del additional_files["Examples directory"]
+ del additional_files["PyTorch pipelines"]
+ del additional_files["TensorFlow pipelines"]
+
additional_results = {
key: {
"failed": {"unclassified": 0, "single": 0, "multi": 0},
@@ -684,7 +721,7 @@ def add_path(self, path: str, gpu: str = None):
for artifact_path in available_artifacts[additional_files[key]].paths:
if artifact_path["gpu"] is not None:
additional_results[key]["job_link"] = github_actions_job_links.get(
- f"{key} ({artifact_path['gpu']}-gpu-docker)"
+ f"{key} ({artifact_path['gpu']}-gpu)"
)
artifact = retrieve_artifact(artifact_path["name"], artifact_path["gpu"])
stacktraces = handle_stacktraces(artifact["failures_line"])
@@ -704,13 +741,54 @@ def add_path(self, path: str, gpu: str = None):
line = line.split()[0].replace("\n", "")
if artifact_path["gpu"] not in additional_results[key]["failures"]:
- additional_results[key]["failures"][artifact_path["gpu"]] = ""
+ additional_results[key]["failures"][artifact_path["gpu"]] = []
+
+ additional_results[key]["failures"][artifact_path["gpu"]].append(
+ {"line": line, "trace": stacktraces.pop(0)}
+ )
+
+ # To find the PR number in a commit title, for example, `Add AwesomeFormer model (#99999)`
+ pr_number_re = re.compile(r"\(#(\d+)\)$")
+
+ title = f"š¤ Results of the {ci_event} tests."
+ # Add PR title with a link for push CI
+ ci_title = os.environ.get("CI_TITLE")
+ ci_url = os.environ.get("CI_COMMIT_URL")
- additional_results[key]["failures"][
- artifact_path["gpu"]
- ] += f"*{line}*\n_{stacktraces.pop(0)}_\n\n"
+ if ci_title is not None:
+ assert ci_url is not None
+ ci_title = ci_title.strip().split("\n")[0].strip()
+
+ # Retrieve the PR title and author login to complete the report
+ commit_number = ci_url.split("/")[-1]
+ ci_detail_url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit_number}"
+ ci_details = requests.get(ci_detail_url).json()
+ ci_author = ci_details["author"]["login"]
+
+ merged_by = None
+ # Find the PR number (if any) and change the url to the actual PR page.
+ numbers = pr_number_re.findall(ci_title)
+ if len(numbers) > 0:
+ pr_number = numbers[0]
+ ci_detail_url = f"https://api.github.com/repos/huggingface/transformers/pulls/{pr_number}"
+ ci_details = requests.get(ci_detail_url).json()
+
+ ci_author = ci_details["user"]["login"]
+ ci_url = f"https://github.com/huggingface/transformers/pull/{pr_number}"
+
+ merged_by = ci_details["merged_by"]["login"]
+
+ if merged_by is None:
+ ci_title = f"<{ci_url}|{ci_title}>\nAuthor: {ci_author}"
+ else:
+ ci_title = f"<{ci_url}|{ci_title}>\nAuthor: {ci_author} | Merged by: {merged_by}"
+
+ else:
+ ci_title = ""
- message = Message("š¤ Results of the scheduled tests.", model_results, additional_results)
+ message = Message(title, ci_title, model_results, additional_results)
- message.post()
- message.post_reply()
+ # send report only if there is any failure
+ if message.n_failures:
+ message.post()
+ message.post_reply()
diff --git a/utils/notification_service_deprecated.py b/utils/notification_service_deprecated.py
deleted file mode 100644
index b14bff1751921a..00000000000000
--- a/utils/notification_service_deprecated.py
+++ /dev/null
@@ -1,217 +0,0 @@
-# Copyright 2020 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Old script for Slack's notification service. Still here as the entire suite has not been moved to the newer implem.
-
-import os
-import re
-import sys
-
-from slack_sdk import WebClient
-
-
-def handle_test_results(test_results):
- expressions = test_results.split(" ")
-
- failed = 0
- success = 0
-
- # When the output is short enough, the output is surrounded by = signs: "== OUTPUT =="
- # When it is too long, those signs are not present.
- time_spent = expressions[-2] if "=" in expressions[-1] else expressions[-1]
-
- for i, expression in enumerate(expressions):
- if "failed" in expression:
- failed += int(expressions[i - 1])
- if "passed" in expression:
- success += int(expressions[i - 1])
-
- return failed, success, time_spent
-
-
-def format_for_slack(total_results, results, scheduled: bool, title: str):
- print(total_results, results)
- header = {
- "type": "header",
- "text": {
- "type": "plain_text",
- "text": title,
- "emoji": True,
- },
- }
-
- if total_results["failed"] > 0:
- total = {
- "type": "section",
- "fields": [
- {"type": "mrkdwn", "text": f"*Failures:*\nā {total_results['failed']} failures."},
- {"type": "mrkdwn", "text": f"*Passed:*\nā
{total_results['success']} tests passed."},
- ],
- }
- else:
- total = {
- "type": "section",
- "fields": [
- {"type": "mrkdwn", "text": "\nš All tests passed."},
- ],
- }
-
- blocks = [header, total]
-
- if total_results["failed"] > 0:
- for key, result in results.items():
- print(key, result)
- blocks.append({"type": "header", "text": {"type": "plain_text", "text": key, "emoji": True}})
- blocks.append(
- {
- "type": "section",
- "fields": [
- {
- "type": "mrkdwn",
- "text": f"*Results:*\n{result['failed']} failed, {result['success']} passed.",
- },
- {"type": "mrkdwn", "text": f"*Time spent:*\n{result['time_spent']}"},
- ],
- }
- )
- elif not scheduled:
- for key, result in results.items():
- blocks.append(
- {"type": "section", "fields": [{"type": "mrkdwn", "text": f"*{key}*\n{result['time_spent']}."}]}
- )
-
- footer = {
- "type": "section",
- "text": {
- "type": "mrkdwn",
- "text": f"",
- },
- }
-
- blocks.append(footer)
-
- blocks = {"blocks": blocks}
-
- return blocks
-
-
-if __name__ == "__main__":
- arguments = sys.argv[1:]
-
- if "scheduled" in arguments:
- arguments.remove("scheduled")
- scheduled = True
- else:
- scheduled = False
-
- if scheduled:
- # The scheduled run has several artifacts for each job.
- file_paths = {
- "TF Single GPU": {
- "common": "run_all_tests_tf_gpu_test_reports/[].txt",
- "pipeline": "run_all_tests_tf_gpu_test_reports/[].txt",
- },
- "Torch Single GPU": {
- "common": "run_all_tests_torch_gpu_test_reports/[].txt",
- "pipeline": "run_all_tests_torch_gpu_test_reports/[].txt",
- "examples": "run_all_tests_torch_gpu_test_reports/[].txt",
- },
- "TF Multi GPU": {
- "common": "run_all_tests_tf_multi_gpu_test_reports/[].txt",
- "pipeline": "run_all_tests_tf_multi_gpu_test_reports/[].txt",
- },
- "Torch Multi GPU": {
- "common": "run_all_tests_torch_multi_gpu_test_reports/[].txt",
- "pipeline": "run_all_tests_torch_multi_gpu_test_reports/[].txt",
- },
- "Torch Cuda Extensions Single GPU": {"common": "run_tests_torch_cuda_extensions_gpu_test_reports/[].txt"},
- "Torch Cuda Extensions Multi GPU": {
- "common": "run_tests_torch_cuda_extensions_multi_gpu_test_reports/[].txt"
- },
- }
- else:
- file_paths = {
- "TF Single GPU": {"common": "run_all_tests_tf_gpu_test_reports/[].txt"},
- "Torch Single GPU": {"common": "run_all_tests_torch_gpu_test_reports/[].txt"},
- "TF Multi GPU": {"common": "run_all_tests_tf_multi_gpu_test_reports/[].txt"},
- "Torch Multi GPU": {"common": "run_all_tests_torch_multi_gpu_test_reports/[].txt"},
- "Torch Cuda Extensions Single GPU": {"common": "run_tests_torch_cuda_extensions_gpu_test_reports/[].txt"},
- "Torch Cuda Extensions Multi GPU": {
- "common": "run_tests_torch_cuda_extensions_multi_gpu_test_reports/[].txt"
- },
- }
-
- client = WebClient(token=os.environ["CI_SLACK_BOT_TOKEN"])
-
- if not scheduled:
- channel_id = os.environ["CI_SLACK_CHANNEL_ID"]
- elif scheduled and len(arguments):
- channel_id = os.environ["CI_SLACK_CHANNEL_ID_PAST_FUTURE"]
- else:
- channel_id = os.environ["CI_SLACK_CHANNEL_ID_DAILY"]
-
- if scheduled:
- title = "š¤ Results of the scheduled tests."
- else:
- title = "š¤ Self-push results"
-
- if len(arguments):
- title = f"{arguments} " + title
-
- try:
- results = {}
- for job, file_dict in file_paths.items():
-
- # Single return value for failed/success across steps of a same job
- results[job] = {"failed": 0, "success": 0, "time_spent": "", "failures": ""}
-
- for key, file_path in file_dict.items():
- try:
- with open(file_path.replace("[]", "stats")) as f:
- failed, success, time_spent = handle_test_results(f.read())
- results[job]["failed"] += failed
- results[job]["success"] += success
- results[job]["time_spent"] += time_spent[1:-1] + ", "
- with open(file_path.replace("[]", "summary_short")) as f:
- for line in f:
- if re.search("FAILED", line):
- results[job]["failures"] += line
- except FileNotFoundError:
- print("Artifact was not found, job was probably canceled.")
-
- # Remove the trailing ", "
- results[job]["time_spent"] = results[job]["time_spent"][:-2]
-
- test_results_keys = ["failed", "success"]
- total = {"failed": 0, "success": 0}
- for job, job_result in results.items():
- for result_key in test_results_keys:
- total[result_key] += job_result[result_key]
-
- if total["failed"] != 0 or scheduled:
- to_be_sent_to_slack = format_for_slack(total, results, scheduled, title)
-
- result = client.chat_postMessage(
- channel=channel_id,
- blocks=to_be_sent_to_slack["blocks"],
- )
-
- for job, job_result in results.items():
- if len(job_result["failures"]):
- client.chat_postMessage(
- channel=channel_id, text=f"{job}\n{job_result['failures']}", thread_ts=result["ts"]
- )
-
- except Exception as e:
- # Voluntarily catch every exception and send it to Slack.
- raise Exception(f"Setup error: no artifacts were found. Error: {e}") from e
diff --git a/utils/notification_service_doc_tests.py b/utils/notification_service_doc_tests.py
index 58ceb567adbdc5..d02b08b605e116 100644
--- a/utils/notification_service_doc_tests.py
+++ b/utils/notification_service_doc_tests.py
@@ -118,7 +118,10 @@ def failures(self) -> Dict:
"type": "section",
"text": {
"type": "plain_text",
- "text": f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in {self.time}.",
+ "text": (
+ f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in"
+ f" {self.time}."
+ ),
"emoji": True,
},
"accessory": {
@@ -286,7 +289,7 @@ def retrieve_artifact(name: str):
files = os.listdir(name)
for file in files:
try:
- with open(os.path.join(name, file)) as f:
+ with open(os.path.join(name, file), encoding="utf-8") as f:
_artifact[file.split(".")[0]] = f.read()
except UnicodeDecodeError as e:
raise ValueError(f"Could not open {os.path.join(name, file)}.") from e
diff --git a/utils/print_env.py b/utils/print_env.py
new file mode 100644
index 00000000000000..443ed6eab6c4b9
--- /dev/null
+++ b/utils/print_env.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# this script dumps information about the environment
+
+import os
+import sys
+
+import transformers
+
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+print("Python version:", sys.version)
+print("transformers version:", transformers.__version__)
+
+try:
+ import torch
+
+ print("Torch version:", torch.__version__)
+ print("Cuda available:", torch.cuda.is_available())
+ print("Cuda version:", torch.version.cuda)
+ print("CuDNN version:", torch.backends.cudnn.version())
+ print("Number of GPUs available:", torch.cuda.device_count())
+ print("NCCL version:", torch.cuda.nccl.version())
+except ImportError:
+ print("Torch version:", None)
+
+try:
+ import deepspeed
+
+ print("DeepSpeed version:", deepspeed.__version__)
+except ImportError:
+ print("DeepSpeed version:", None)
+
+try:
+ import tensorflow as tf
+
+ print("TensorFlow version:", tf.__version__)
+ print("TF GPUs available:", bool(tf.config.list_physical_devices("GPU")))
+ print("Number of TF GPUs available:", len(tf.config.list_physical_devices("GPU")))
+except ImportError:
+ print("TensorFlow version:", None)
diff --git a/utils/print_env_pt.py b/utils/print_env_pt.py
deleted file mode 100755
index 94451541f64664..00000000000000
--- a/utils/print_env_pt.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/usr/bin/env python3
-
-# coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# this script dumps information about the environment
-
-import torch
-
-
-print("Torch version:", torch.__version__)
-print("Cuda available:", torch.cuda.is_available())
-print("Cuda version:", torch.version.cuda)
-print("CuDNN version:", torch.backends.cudnn.version())
-print("Number of GPUs available:", torch.cuda.device_count())
-print("NCCL version:", torch.cuda.nccl.version())
diff --git a/utils/release.py b/utils/release.py
index 5a9c15f6ae06b0..3bb75f0bebf45e 100644
--- a/utils/release.py
+++ b/utils/release.py
@@ -123,7 +123,7 @@ def pre_release_work(patch=False):
print(f"Updating version to {version}.")
global_version_update(version, patch=patch)
if not patch:
- print("Cleaning main README")
+ print("Cleaning main README, don't forget to run `make fix-copies`.")
clean_main_ref_in_model_list()
@@ -141,6 +141,8 @@ def post_release_work():
print(f"Updating version to {version}.")
global_version_update(version)
+ print("Cleaning main README, don't forget to run `make fix-copies`.")
+ clean_main_ref_in_model_list()
if __name__ == "__main__":
diff --git a/utils/sort_auto_mappings.py b/utils/sort_auto_mappings.py
new file mode 100644
index 00000000000000..ef985dc43cd4f4
--- /dev/null
+++ b/utils/sort_auto_mappings.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import re
+
+
+PATH_TO_AUTO_MODULE = "src/transformers/models/auto"
+
+
+# re pattern that matches mapping introductions:
+# SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict
+_re_intro_mapping = re.compile("[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
+# re pattern that matches identifiers in mappings
+_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
+
+
+def sort_auto_mapping(fname, overwrite: bool = False):
+ with open(fname, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ lines = content.split("\n")
+ new_lines = []
+ line_idx = 0
+ while line_idx < len(lines):
+ if _re_intro_mapping.search(lines[line_idx]) is not None:
+ indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
+ # Start of a new mapping!
+ while not lines[line_idx].startswith(" " * indent + "("):
+ new_lines.append(lines[line_idx])
+ line_idx += 1
+
+ blocks = []
+ while lines[line_idx].strip() != "]":
+ # Blocks either fit in one line or not
+ if lines[line_idx].strip() == "(":
+ start_idx = line_idx
+ while not lines[line_idx].startswith(" " * indent + ")"):
+ line_idx += 1
+ blocks.append("\n".join(lines[start_idx : line_idx + 1]))
+ else:
+ blocks.append(lines[line_idx])
+ line_idx += 1
+
+ # Sort blocks by their identifiers
+ blocks = sorted(blocks, key=lambda x: _re_identifier.search(x).groups()[0])
+ new_lines += blocks
+ else:
+ new_lines.append(lines[line_idx])
+ line_idx += 1
+
+ if overwrite:
+ with open(fname, "w", encoding="utf-8") as f:
+ f.write("\n".join(new_lines))
+ elif "\n".join(new_lines) != content:
+ return True
+
+
+def sort_all_auto_mappings(overwrite: bool = False):
+ fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py")]
+ diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]
+
+ if not overwrite and any(diffs):
+ failures = [f for f, d in zip(fnames, diffs) if d]
+ raise ValueError(
+ f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
+ " this."
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
+ args = parser.parse_args()
+
+ sort_all_auto_mappings(not args.check_only)
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index 16bf6348d387d2..7acebcabb4e473 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -15,6 +15,7 @@
import argparse
import collections
+import json
import os
import re
from contextlib import contextmanager
@@ -65,6 +66,32 @@ def clean_code(content):
return "\n".join(lines_to_keep)
+def get_all_tests():
+ """
+ Return a list of paths to all test folders and files under `tests`. All paths are rooted at `tests`.
+
+ - folders under `tests`: `tokenization`, `pipelines`, etc. The folder `models` is excluded.
+ - folders under `tests/models`: `bert`, `gpt2`, etc.
+ - test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.
+ """
+ test_root_dir = os.path.join(PATH_TO_TRANFORMERS, "tests")
+
+ # test folders/files directly under `tests` folder
+ tests = os.listdir(test_root_dir)
+ tests = sorted(
+ list(filter(lambda x: os.path.isdir(x) or x.startswith("tests/test_"), [f"tests/{x}" for x in tests]))
+ )
+
+ # model specific test folders
+ model_tests_folders = os.listdir(os.path.join(test_root_dir, "models"))
+ model_test_folders = sorted(list(filter(os.path.isdir, [f"tests/models/{x}" for x in model_tests_folders])))
+
+ tests.remove("tests/models")
+ tests = model_test_folders + tests
+
+ return tests
+
+
def diff_is_docstring_only(repo, branching_point, filename):
"""
Check if the diff is only in docstrings in a filename.
@@ -215,6 +242,67 @@ def get_test_dependencies(test_fname):
return [f for f in [*parent_imports, *current_dir_imports] if os.path.isfile(f)]
+def create_reverse_dependency_tree():
+ """
+ Create a list of all edges (a, b) which mean that modifying a impacts b with a going over all module and test files.
+ """
+ modules = [
+ str(f.relative_to(PATH_TO_TRANFORMERS))
+ for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
+ ]
+ module_edges = [(d, m) for m in modules for d in get_module_dependencies(m)]
+
+ tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
+ test_edges = [(d, t) for t in tests for d in get_test_dependencies(t)]
+
+ return module_edges + test_edges
+
+
+def get_tree_starting_at(module, edges):
+ """
+ Returns the tree starting at a given module following all edges in the following format: [module, [list of edges
+ starting at module], [list of edges starting at the preceding level], ...]
+ """
+ vertices_seen = [module]
+ new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module]
+ tree = [module]
+ while len(new_edges) > 0:
+ tree.append(new_edges)
+ final_vertices = list(set(edge[1] for edge in new_edges))
+ vertices_seen.extend(final_vertices)
+ new_edges = [edge for edge in edges if edge[0] in final_vertices and edge[1] not in vertices_seen]
+
+ return tree
+
+
+def print_tree_deps_of(module, all_edges=None):
+ """
+ Prints the tree of modules depending on a given module.
+ """
+ if all_edges is None:
+ all_edges = create_reverse_dependency_tree()
+ tree = get_tree_starting_at(module, all_edges)
+
+ # The list of lines is a list of tuples (line_to_be_printed, module)
+ # Keeping the modules lets us know where to insert each new lines in the list.
+ lines = [(tree[0], tree[0])]
+ for index in range(1, len(tree)):
+ edges = tree[index]
+ start_edges = set([edge[0] for edge in edges])
+
+ for start in start_edges:
+ end_edges = set([edge[1] for edge in edges if edge[0] == start])
+ # We will insert all those edges just after the line showing start.
+ pos = 0
+ while lines[pos][1] != start:
+ pos += 1
+ lines = lines[: pos + 1] + [(" " * (2 * index) + end, end) for end in end_edges] + lines[pos + 1 :]
+
+ for line in lines:
+ # We don't print the refs that where just here to help build lines.
+ print(line[0])
+
+
def create_reverse_dependency_map():
"""
Create the dependency map from module/test filename to the list of modules/tests that depend on it (even
@@ -268,25 +356,28 @@ def create_reverse_dependency_map():
"feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
"feature_extraction_utils.py": "test_feature_extraction_common.py",
"file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
- "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
+ "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
"utils/hub.py": "utils/test_file_utils.py",
"modelcard.py": "utils/test_model_card.py",
"modeling_flax_utils.py": "test_modeling_flax_common.py",
"modeling_tf_utils.py": ["test_modeling_tf_common.py", "utils/test_modeling_tf_core.py"],
"modeling_utils.py": ["test_modeling_common.py", "utils/test_offline.py"],
"models/auto/modeling_auto.py": [
- "auto/test_modeling_auto.py",
- "auto/test_modeling_tf_pytorch.py",
- "bort/test_modeling_bort.py",
- "dit/test_modeling_dit.py",
+ "models/auto/test_modeling_auto.py",
+ "models/auto/test_modeling_tf_pytorch.py",
+ "models/bort/test_modeling_bort.py",
+ "models/dit/test_modeling_dit.py",
],
- "models/auto/modeling_flax_auto.py": "auto/test_modeling_flax_auto.py",
+ "models/auto/modeling_flax_auto.py": "models/auto/test_modeling_flax_auto.py",
"models/auto/modeling_tf_auto.py": [
- "auto/test_modeling_tf_auto.py",
- "auto/test_modeling_tf_pytorch.py",
- "bort/test_modeling_tf_bort.py",
+ "models/auto/test_modeling_tf_auto.py",
+ "models/auto/test_modeling_tf_pytorch.py",
+ "models/bort/test_modeling_tf_bort.py",
+ ],
+ "models/gpt2/modeling_gpt2.py": [
+ "models/gpt2/test_modeling_gpt2.py",
+ "models/megatron_gpt2/test_modeling_megatron_gpt2.py",
],
- "models/gpt2/modeling_gpt2.py": ["gpt2/test_modeling_gpt2.py", "megatron_gpt2/test_modeling_megatron_gpt2.py"],
"optimization.py": "optimization/test_optimization.py",
"optimization_tf.py": "optimization/test_optimization_tf.py",
"pipelines/base.py": "pipelines/test_pipelines_*.py",
@@ -350,7 +441,7 @@ def module_to_test_file(module_fname):
elif len(splits) > 0 and splits[0] == "utils":
default_test_file = f"tests/utils/test_utils_{module_name}"
elif len(splits) > 4 and splits[2] == "models":
- default_test_file = f"tests/{splits[3]}/test_{module_name}"
+ default_test_file = f"tests/models/{splits[3]}/test_{module_name}"
elif len(splits) > 2 and splits[2].startswith("generation"):
default_test_file = f"tests/generation/test_{module_name}"
elif len(splits) > 2 and splits[2].startswith("trainer"):
@@ -438,7 +529,7 @@ def sanity_check():
)
-def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
+def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, json_output_file=None):
modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
@@ -492,6 +583,42 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
with open(output_file, "w", encoding="utf-8") as f:
f.write(" ".join(test_files_to_run))
+ # Create a map that maps test categories to test files, i.e. `models/bert` -> [...test_modeling_bert.py, ...]
+
+ # Get all test directories (and some common test files) under `tests` and `tests/models` if `test_files_to_run`
+ # contains `tests` (i.e. when `setup.py` is changed).
+ if "tests" in test_files_to_run:
+ test_files_to_run = get_all_tests()
+
+ if json_output_file is not None:
+ test_map = {}
+ for test_file in test_files_to_run:
+ # `test_file` is a path to a test folder/file, starting with `tests/`. For example,
+ # - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
+ # - `tests/trainer/test_trainer.py` or `tests/trainer`
+ # - `tests/test_modeling_common.py`
+ names = test_file.split(os.path.sep)
+ if names[1] == "models":
+ # take the part like `models/bert` for modeling tests
+ key = "/".join(names[1:3])
+ elif len(names) > 2 or not test_file.endswith(".py"):
+ # test folders under `tests` or python files under them
+ # take the part like tokenization, `pipeline`, etc. for other test categories
+ key = "/".join(names[1:2])
+ else:
+ # common test files directly under `tests/`
+ key = "common"
+
+ if key not in test_map:
+ test_map[key] = []
+ test_map[key].append(test_file)
+
+ # sort the keys & values
+ keys = sorted(test_map.keys())
+ test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
+ with open(json_output_file, "w", encoding="UTF-8") as fp:
+ json.dump(test_map, fp, ensure_ascii=False)
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -501,6 +628,12 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
parser.add_argument(
"--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run"
)
+ parser.add_argument(
+ "--json_output_file",
+ type=str,
+ default="test_map.json",
+ help="Where to store the tests to run in a dictionary format mapping test categories to test files",
+ )
parser.add_argument(
"--diff_with_last_commit",
action="store_true",
@@ -513,8 +646,16 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
default=["tests"],
help="Only keep the test files matching one of those filters.",
)
+ parser.add_argument(
+ "--print_dependencies_of",
+ type=str,
+ help="Will only print the tree of modules depending on the file passed.",
+ default=None,
+ )
args = parser.parse_args()
- if args.sanity_check:
+ if args.print_dependencies_of is not None:
+ print_tree_deps_of(args.print_dependencies_of)
+ elif args.sanity_check:
sanity_check()
else:
repo = Repo(PATH_TO_TRANFORMERS)
@@ -525,7 +666,12 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
diff_with_last_commit = True
try:
- infer_tests_to_run(args.output_file, diff_with_last_commit=diff_with_last_commit, filters=args.filters)
+ infer_tests_to_run(
+ args.output_file,
+ diff_with_last_commit=diff_with_last_commit,
+ filters=args.filters,
+ json_output_file=args.json_output_file,
+ )
except Exception as e:
print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
with open(args.output_file, "w", encoding="utf-8") as f: