[Inference] Add t5 support for export and inference #267
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Awesome! I have a few questions/issues regarding the generation code though.
Left a few comments, and opened the discussion on `generate` methods: would it be possible to re-use what we already have?
I noticed that all the code for this new model class lives in
Thanks for taking the time to merge the generation code. Most of my questions are related to this part, which I did not review when it was first merged: I find it a bit cryptic where it diverges from the original transformers code I am more familiar with.
My main concern is how this is actually compatible with the generation config / generation parameters that can be passed to `generate`.
So, if possible, I would like to have some unit tests on that before formally approving.
    (batch_size, num_padding_values), dtype=attention_mask.dtype, device=attention_mask.device
),
attention_mask,
torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device),
Why do you need this extra one? Is it for a new token that has been generated?
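(A toy illustration, not from the PR, of what the concatenation above builds; the batch size, prompt length, and num_padding_values are made up, and the reason for the trailing column of ones is exactly what the question above is asking about.)

```python
import torch

# Made-up toy sizes: batch of 1, a 3-token prompt, 4 padding slots.
batch_size, num_padding_values = 1, 4
attention_mask = torch.ones((batch_size, 3), dtype=torch.int64)

padded_mask = torch.cat(
    [
        torch.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype),
        attention_mask,
        # the extra column of ones asked about above
        torch.ones((batch_size, 1), dtype=attention_mask.dtype),
    ],
    dim=-1,
)
print(padded_mask)  # tensor([[0, 0, 0, 0, 1, 1, 1, 1]])
```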
decoder_attention_mask = torch.cat(
    [
        torch.zeros((batch_size, num_padding_values), dtype=torch.int32),
        torch.ones((batch_size, 2), dtype=torch.int32),
Where do these two input ids come from?
@staticmethod
def _initialize_past(past_key_values, num_padding_values):
    """Initialize past_key_values with zeros -- the structure depends on `batch_axis`"""
You actually keep the first two tokens of the past_key_values and pad left with num_padding_values zeros.
I am not using this function at all, I only moved it out of `_update_model_kwargs_for_xla_generation` so that it would be easier to override.
new_past = ()
for past_layer in past_key_values:
    new_past_layer = list(past_layer)
    for i, _ in enumerate(new_past_layer[:2]):
That's basically `for i in range(2)`, isn't it?
I think so, but it was not me who contributed it, so I am not sure if there is a specific intention here...
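(For context, a small hedged sketch of what the fragment and the comments above describe. It assumes self-attention key/value tensors of shape `(batch, num_heads, seq_len, head_dim)` and left padding along the sequence axis, as stated in the comment above; the exact tensor layout and padding side are assumptions, not confirmed in this thread.)

```python
import torch

def _initialize_past_sketch(past_key_values, num_padding_values):
    """Left-pad the self-attention key/value tensors of each layer with zeros
    along the sequence axis so their length becomes static.
    Only the first two tensors of each layer tuple are touched, which is what
    the `new_past_layer[:2]` slice in the fragment above selects."""
    new_past = ()
    for past_layer in past_key_values:
        new_past_layer = list(past_layer)
        for i in range(2):  # equivalent to `for i, _ in enumerate(new_past_layer[:2])`
            pad = torch.zeros(
                past_layer[i].shape[0],  # batch
                past_layer[i].shape[1],  # num_heads
                num_padding_values,      # padded sequence positions
                past_layer[i].shape[3],  # head_dim
                dtype=past_layer[i].dtype,
                device=past_layer[i].device,
            )
            new_past_layer[i] = torch.cat([pad, past_layer[i]], dim=2)
        new_past += (tuple(new_past_layer),)
    return new_past
```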
if synced_gpus:
    # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
    # The following logic allows an early break if all peers finished generating their sequence
    this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
    # send 0.0 if we finished, 1.0 otherwise
    dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
    # did all peers finish? the reduced sum will be 0.0 then
    if this_peer_finished_flag.item() == 0.0:
        break
You keep this because you plan to use some kind of distribution?
No clue. I wanted to remove it at some point, but it seems it could possibly be usable on XLA devices...
Yes, I am not sure that I understand the logic here, but it seems that the operations used should be compatible with our system.
Not sure I understand everything, but overall looks good to me.
Left a few nits and comments.
r""" | ||
Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be | ||
used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. | ||
|
||
<Tip warning={true}> | ||
|
||
In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() | ||
instead. For an overview of generation strategies and code examples, check the [following | ||
guide](../generation_strategies). | ||
|
||
</Tip> | ||
|
||
|
||
Parameters: | ||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | ||
The sequence used as a prompt for the generation. | ||
logits_processor (`LogitsProcessorList`, *optional*): | ||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] | ||
used to modify the prediction scores of the language modeling head applied at each generation step. | ||
stopping_criteria (`StoppingCriteriaList`, *optional*): | ||
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] | ||
used to tell if the generation loop should stop. | ||
|
||
max_length (`int`, *optional*, defaults to 20): | ||
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated | ||
tokens. The maximum length of the sequence to be generated. | ||
pad_token_id (`int`, *optional*): | ||
The id of the *padding* token. | ||
eos_token_id (`Union[int, List[int]]`, *optional*): | ||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. | ||
output_attentions (`bool`, *optional*, defaults to `False`): | ||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under | ||
returned tensors for more details. | ||
output_hidden_states (`bool`, *optional*, defaults to `False`): | ||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors | ||
for more details. | ||
output_scores (`bool`, *optional*, defaults to `False`): | ||
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. | ||
return_dict_in_generate (`bool`, *optional*, defaults to `False`): | ||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | ||
synced_gpus (`bool`, *optional*, defaults to `False`): | ||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) | ||
seq_length (`Optional[int]`, defaults to `False`): | ||
Length of current input_ids sequence | ||
is_traced_inference (`bool`, defaults to `False`): | ||
Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores | ||
are computed inside the decoder. | ||
model_kwargs: | ||
Additional model specific keyword arguments will be forwarded to the `forward` function of the model. | ||
If model is an encoder-decoder model the kwargs should include `encoder_outputs`. | ||
|
||
Return: | ||
[`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or | ||
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a | ||
[`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and | ||
`return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if | ||
`model.config.is_encoder_decoder=True`. | ||
|
||
Examples: | ||
|
||
```python | ||
>>> from transformers import AutoTokenizer | ||
>>> from optimum.neuron import NeuronModelForSeq2SeqLM | ||
|
||
>>> tokenizer = AutoTokenizer.from_pretrained("t5-small") | ||
>>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 1} | ||
>>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes) | ||
|
||
>>> input_prompt = "translate English to German: Lets eat good food." | ||
>>> inputs = tokenizer(input_prompt, return_tensors="pt") | ||
|
||
>>> outputs = model.greedy_search(input_ids) | ||
|
||
>>> results = [tokenizer.decode(t, skip_special_tokens=True) for t in outputs] | ||
``` | ||
""" |
I don't know if we should copy/paste the full docstrings of the functions we override.
I am personally not much in favor of doing so because it adds maintenance on our side for not so much added value. But it's just a thought.
I wanted to add the explanation for `is_traced_inference`, and felt it would be better to have the docstring complete, especially since we will tailor quite a lot of those args to be compatible with Neuron in the near future; it would be better to have a complete docstring explaining that.
Thank you very much for this awesome model! And thanks also for adding the tests, I think it was worth the extra time since you caught some issues.
What's in the PR
Quick Tests
Exporter
optimum-cli export neuron --model t5-small --task text2text-generation --batch_size 1 --sequence_length 18 --num_beams 4 t5_small_neuron/
Inference
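A minimal inference sketch, not copied from the PR: it assumes the `t5_small_neuron/` directory produced by the export command above, the standard `optimum.neuron` loading API shown in the docstring earlier, and sticks to greedy decoding given the beam-search caveat below.

```python
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForSeq2SeqLM

# Load the artifacts produced by the export command above (assumed local path).
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = NeuronModelForSeq2SeqLM.from_pretrained("t5_small_neuron/")

inputs = tokenizer("translate English to German: Let's eat good food.", return_tensors="pt")
# Note: generation parameters generally need to match the static shapes the model was
# compiled with (batch_size, sequence_length, num_beams from the export command).
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```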
[Caveat] Beam search is not working yet. Got the following error while running beam search with the official example. Will debug and add support in a coming PR.
Error Log Beam Search
Found the issue: in the notebook, the runtime was not initialized for beam search, leading to the error.
Next Steps
neuronx-distributed