Merge pull request huggingface#2 from DaryaTereshchenko/changes_to_pr1
add fixes and documentation
DaryaTereshchenko authored Nov 2, 2024
2 parents eafd847 + c70f864 commit 8350215
Showing 11 changed files with 187 additions and 100 deletions.
26 changes: 0 additions & 26 deletions =0.26.0

This file was deleted.

1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -265,6 +265,7 @@ Flax), PyTorch, and/or TensorFlow.
| [PLBart](model_doc/plbart) ||||
| [PoolFormer](model_doc/poolformer) ||||
| [Pop2Piano](model_doc/pop2piano) ||||
| [Prism](model_doc/prism) ||||
| [ProphetNet](model_doc/prophetnet) ||||
| [PVT](model_doc/pvt) ||||
| [PVTv2](model_doc/pvt_v2) ||||
110 changes: 110 additions & 0 deletions docs/source/en/model_doc/prism.md
@@ -0,0 +1,110 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# PRISM

## Overview

The `Prism` model is a multilingual neural machine translation (NMT) system that supports translation across 39 languages. Beyond translation, it serves as a reference-based metric for machine translation evaluation, leveraging a zero-shot paraphrasing approach that does not require human judgments for training.

The `Prism` model was designed to be a lexically/syntactically unbiased paraphraser. The core idea is to treat paraphrasing as a zero-shot translation task, which allows the model to cover a wide range of languages effectively.

The model was proposed in [Automatic Machine Translation Evaluation in Many Languages via Zero-Shot Paraphrasing](https://aclanthology.org/2020.emnlp-main.8.pdf) by Brian Thompson and Matt Post.

The abstract from the paper is the following:

*We frame the task of machine translation evaluation as one of scoring machine translation output with a sequence-to-sequence paraphraser, conditioned on a human reference. We propose training the paraphraser as a multilingual NMT system, treating paraphrasing as a zero-shot translation task (e.g., Czech to Czech). This results in the paraphraser’s output mode being centered around a copy of the input sequence, which represents the best case scenario where the MT system output matches a human reference. Our method is simple and intuitive, and does not require human judgements for training. Our single model (trained in 39 languages) outperforms or statistically ties with all prior metrics on the WMT 2019 segment-level shared metrics task in all languages (excluding Gujarati where the model had no training data). We also explore using our model for the task of quality estimation as a metric—conditioning on the source instead of the reference—and find that it significantly outperforms every submission to the WMT 2019 shared task on quality estimation in every language pair.*

This model was contributed by [dariast](https://huggingface.co/dariast/).
The original code can be found [here](https://github.com/thompsonb/prism/tree/master) and the original documentation [here](https://github.com/thompsonb/prism/blob/master/translation/README.md).


## Usage tips

To use `PrismTokenizer`, ensure that the `sentencepiece` package is installed, as it is a required dependency for handling multilingual tokenization.

```bash
pip install sentencepiece
```

## Example
```python
from transformers import PrismForConditionalGeneration, PrismTokenizer

uk_text = "Життя як коробка шоколаду"
ja_text = "人生はチョコレートの箱のようなもの。"

model = PrismForConditionalGeneration.from_pretrained("dariast/prism")
tokenizer = PrismTokenizer.from_pretrained("dariast/prism")

# Translate Ukrainian to French
tokenizer.src_lang = "uk"
encoded_uk = tokenizer(uk_text, return_tensors="pt")
generated_tokens = model.generate(**encoded_uk, forced_bos_token_id=tokenizer.get_lang_id("fr"), max_new_tokens=20)
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
# => 'La vie comme une boîte de chocolat.'

# Translate Japanese to English
tokenizer.src_lang = "ja"
encoded_ja = tokenizer(ja_text, return_tensors="pt")
generated_tokens = model.generate(**encoded_ja, forced_bos_token_id=tokenizer.get_lang_id("en"), max_new_tokens=20)
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
# => 'Life is like a box of chocolate.'
```
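
Because Prism treats paraphrasing as zero-shot translation (e.g. English to English), it can also be used to score how well a candidate sentence paraphrases a human reference, which is the idea behind the Prism evaluation metric. The snippet below is a minimal sketch of such scoring via the model's label cross-entropy; it is an illustration rather than the exact metric from the paper, and it assumes `PrismTokenizer` follows the same `src_lang`/`tgt_lang` and `text_target` conventions as the library's other multilingual seq2seq tokenizers.

```python
import torch

from transformers import PrismForConditionalGeneration, PrismTokenizer

model = PrismForConditionalGeneration.from_pretrained("dariast/prism")
tokenizer = PrismTokenizer.from_pretrained("dariast/prism")

reference = "Life is like a box of chocolates."
hypothesis = "Life is similar to a box of chocolate."

# Paraphrasing as zero-shot translation: English -> English.
tokenizer.src_lang = "en"
tokenizer.tgt_lang = "en"  # assumed to exist, mirroring other multilingual tokenizers
inputs = tokenizer(reference, text_target=hypothesis, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# The mean cross-entropy over the hypothesis tokens acts as a (negated) paraphrase
# score: a lower value means the hypothesis is a more probable "translation"
# (i.e. paraphrase) of the reference.
print(f"Negative log-likelihood per token: {outputs.loss.item():.3f}")
```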

## Languages Covered
Albanian (sq), Arabic (ar), Bengali (bn), Bulgarian (bg), Catalan; Valencian (ca), Chinese (zh), Croatian (hr), Czech (cs), Danish (da), Dutch (nl), English (en), Esperanto (eo), Estonian (et), Finnish (fi), French (fr), German (de), Greek, Modern (el), Hebrew (modern) (he), Hungarian (hu), Indonesian (id), Italian (it), Japanese (ja), Kazakh (kk), Latvian (lv), Lithuanian (lt), Macedonian (mk), Norwegian (no), Polish (pl), Portuguese (pt), Romanian, Moldovan (ro), Russian (ru), Serbian (sr), Slovak (sk), Slovene (sl), Spanish; Castilian (es), Swedish (sv), Turkish (tr), Ukrainian (uk), Vietnamese (vi).


## Resources

- [Translation task guide](../tasks/translation)

## PrismConfig

[[autodoc]] PrismConfig

## PrismTokenizer

[[autodoc]] PrismTokenizer
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- save_vocabulary

## PrismModel

[[autodoc]] PrismModel
- forward

## PrismForConditionalGeneration

[[autodoc]] PrismForConditionalGeneration
- forward

## Using Flash Attention 2

Flash Attention 2 is a faster, optimized implementation of the attention computation that relies on custom CUDA kernels.

### Installation

First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).

Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```
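
### Usage

Once Flash Attention 2 is installed, it can be enabled at load time through the standard `attn_implementation` argument of `from_pretrained`. The following is a minimal sketch, assuming a compatible CUDA GPU and that the `dariast/prism` checkpoint runs in half precision:

```python
import torch

from transformers import PrismForConditionalGeneration, PrismTokenizer

# Flash Attention 2 requires a supported CUDA GPU and half precision (fp16/bf16).
model = PrismForConditionalGeneration.from_pretrained(
    "dariast/prism",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
tokenizer = PrismTokenizer.from_pretrained("dariast/prism")

tokenizer.src_lang = "uk"
encoded = tokenizer("Життя як коробка шоколаду", return_tensors="pt").to("cuda")
generated = model.generate(
    **encoded, forced_bos_token_id=tokenizer.get_lang_id("fr"), max_new_tokens=20
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```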
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -82,6 +82,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
* [Prism](https://huggingface.co/docs/transformers/model_doc/prism)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
@@ -255,6 +256,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
* [Prism](https://huggingface.co/docs/transformers/model_doc/prism)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [mBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -218,6 +218,7 @@
("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"),
("pop2piano", "Pop2PianoConfig"),
("prism", "PrismConfig"),
("prophetnet", "ProphetNetConfig"),
("pvt", "PvtConfig"),
("pvt_v2", "PvtV2Config"),
@@ -534,6 +535,7 @@
("plbart", "PLBart"),
("poolformer", "PoolFormer"),
("pop2piano", "Pop2Piano"),
("prism", "Prism"),
("prophetnet", "ProphetNet"),
("pvt", "PVT"),
("pvt_v2", "PVTv2"),
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -204,6 +204,7 @@
("pixtral", "PixtralVisionModel"),
("plbart", "PLBartModel"),
("poolformer", "PoolFormerModel"),
("prism", "PrismModel"),
("prophetnet", "ProphetNetModel"),
("pvt", "PvtModel"),
("pvt_v2", "PvtV2Model"),
@@ -434,6 +435,7 @@
("pegasus_x", "PegasusXForConditionalGeneration"),
("plbart", "PLBartForConditionalGeneration"),
("pop2piano", "Pop2PianoForConditionalGeneration"),
("prism", "PrismForConditionalGeneration"),
("qdqbert", "QDQBertForMaskedLM"),
("reformer", "ReformerModelWithLMHead"),
("rembert", "RemBertForMaskedLM"),
@@ -891,6 +893,7 @@
("pegasus", "PegasusForConditionalGeneration"),
("pegasus_x", "PegasusXForConditionalGeneration"),
("plbart", "PLBartForConditionalGeneration"),
("prism", "PrismForConditionalGeneration"),
("prophetnet", "ProphetNetForConditionalGeneration"),
("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
("seamless_m4t", "SeamlessM4TForTextToText"),
15 changes: 10 additions & 5 deletions src/transformers/models/prism/configuration_prism.py
@@ -28,10 +28,15 @@


class PrismConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a `PrismModel`. It is used to instantiate a
r"""
This is the configuration class to store the configuration of a [`PrismModel`]. It is used to instantiate a
Prism model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Prism architecture as described in the paper.
with the defaults will yield a similar configuration to that of the Prism
[dariast/prism](https://huggingface.co/dariast/prism) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 65400):
@@ -86,10 +91,10 @@ class PrismConfig(PretrainedConfig):
```python
>>> from transformers import PrismConfig, PrismModel
>>> # Initializing a Prism facebook/prism style configuration
>>> # Initializing a Prism dariast/prism style configuration
>>> configuration = PrismConfig()
>>> # Initializing a model (with random weights) from the facebook/prism style configuration
>>> # Initializing a model (with random weights) from the dariast/prism style configuration
>>> model = PrismModel(configuration)
>>> # Accessing the model configuration
39 changes: 1 addition & 38 deletions src/transformers/models/prism/modeling_prism.py
@@ -1397,6 +1397,7 @@ def __init__(self, config: PrismConfig):
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.shared = PrismScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)

self.encoder = PrismEncoder(config, self.shared)
self.decoder = PrismDecoder(config, self.shared)

@@ -1517,7 +1518,6 @@ class PrismForConditionalGeneration(PrismPreTrainedModel, GenerationMixin):
def __init__(self, config: PrismConfig):
super().__init__(config)
self.model = PrismModel(config)
# self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

# Initialize weights and apply final processing
@@ -1620,43 +1620,6 @@ def forward(
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
5 changes: 4 additions & 1 deletion src/transformers/models/prism/tokenization_prism.py
@@ -192,7 +192,10 @@ def convert_tokens_to_string(self, tokens):
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
# Skip language tokens during decoding
if token in self.lang_code_to_token.values():
continue
# Ensure special tokens are not decoded with the sentencepiece model
if token in self.all_special_tokens:
out_string += self.sp_model.decode(current_sub_tokens) + token
current_sub_tokens = []