[Inference] Add t5 support for export and inference #267

Merged: 38 commits merged into main from add-t5-export on Dec 4, 2023

Conversation

@JingyaHuang (Collaborator) commented Oct 23, 2023

What's in the PR

  • T5 exporter support (2 parts: encoder and decoder)
  • T5 inference support (Seq2Seq modeling)
  • Tests

Quick Tests

Exporter

optimum-cli export neuron --model t5-small --task text2text-generation --batch_size 1 --sequence_length 18 --num_beams 4 t5_small_neuron/

Inference

from optimum.neuron import NeuronModelForSeq2SeqLM
from transformers import AutoTokenizer

# Compile the model with static input shapes and save the Neuron artifacts.
model_id = "t5-small"
input_shapes = {
    "batch_size": 1,
    "sequence_length": 64,
    "num_beams": 4,
}
neuron_model = NeuronModelForSeq2SeqLM.from_pretrained(model_id, export=True, dynamic_batch_size=False, **input_shapes)
save_path = "t5_small_neuronx/"
neuron_model.save_pretrained(save_path)
del neuron_model

# Reload the compiled model and run generation.
neuron_model = NeuronModelForSeq2SeqLM.from_pretrained(save_path)
tokenizer = AutoTokenizer.from_pretrained(save_path)
prompt = "translate English to German: Lets eat good food."
inputs = tokenizer(prompt, return_tensors="pt")
num_return_sequences = 1

output = neuron_model.generate(
    **inputs,
    num_return_sequences=num_return_sequences,
)
results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

print("Results:")
for i, summary in enumerate(results):
    print(i + 1, summary)

[Caveat] Beam search is not working yet. I got the following error while running beam search with the official example. I will debug and add support in a follow-up PR.
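For reference, a minimal sketch of the call that triggers the failure, reusing the compiled model and tokenizer from the snippet above (the num_beams and num_return_sequences values are illustrative, and num_beams is assumed to need to match the value used at export):

# Beam search kicks in when num_beams > 1; this is the path that currently fails.
beam_outputs = neuron_model.generate(
    **inputs,
    num_beams=4,              # assumed to match the num_beams passed at compilation time
    num_return_sequences=2,   # illustrative value
    do_sample=False,
)
print([tokenizer.decode(t, skip_special_tokens=True) for t in beam_outputs])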

Error log (beam search):
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
/home/ubuntu/optimum-neuron/optimum/neuron/generation/utils.py:824: UserWarning: use_cache is not supported for generation on Neuron devices, switching to use_cache=False.
  warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.")
2023-Nov-05 17:52:36.718701 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input13 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718727 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input24 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718734 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input16 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718744 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input21 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718751 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input12 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718756 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input25 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718767 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input19 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718771 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input22 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718777 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input5 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718785 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input15 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718791 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input26 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718798 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input18 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718804 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input23 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718811 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input28 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718817 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input6 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718824 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input11 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718831 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input14 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718838 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input27 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718843 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input8 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718853 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input7 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718863 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input10 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718870 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input17 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718879 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input20 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718888 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor input9 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718898 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output9 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718907 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output11 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718916 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output26 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718926 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output16 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718936 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output7 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718945 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output8 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718954 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output12 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718963 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output3 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718970 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output21 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718979 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output17 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718986 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output6 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.718995 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output25 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719007 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output13 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719022 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output20 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719034 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output18 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719045 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output5 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719053 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output24 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719064 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output14 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719073 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output23 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719080 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output19 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719089 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output4 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719099 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output10 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719108 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output15 allocated on logical nc 0 must be allocated on same nc as model (1)
2023-Nov-05 17:52:36.719117 32119:32119 ERROR  NMGR:dlr_check_tensor_set_on_same_tpb_v2     Tensor output22 allocated on logical nc 0 must be allocated on same nc as model (1)
Traceback (most recent call last):
  File "test_t5.py", line 923, in <module>
    output = model.generate(
  File "test_t5.py", line 390, in generate
    output = super().generate(
  File "/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/optimum-neuron/optimum/neuron/generation/utils.py", line 1041, in generate
    return self.beam_search(
  File "test_t5.py", line 497, in beam_search
    next_token_scores, next_tokens, next_indices = self(
  File "/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "test_t5.py", line 418, in forward
    decoder_outputs = self.decoder(
  File "/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torch_neuronx/xla_impl/trace/___torch_mangle_3.py", line 65, in forward
    _0 = getattr(states22, "0")
    _24 = [argument_1, argument_2, argument_4, argument_5, argument_6, _0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23]
    _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, = ops.neuron.forward_v2(_24, model)
                                                                                                                                             ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _52 = [_25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51]
    return _52

Traceback of TorchScript, original code (most recent call last):
/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch/_ops.py(442): __call__
/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch_neuronx/xla_impl/trace.py(101): forward
/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch/nn/modules/module.py(1182): _slow_forward
/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch/nn/modules/module.py(1194): _call_impl
/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch/jit/_trace.py(976): trace_module
/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch/jit/_trace.py(759): trace
/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch_neuronx/xla_impl/trace.py(422): create_neuron_model
/home/ubuntu/pyvenv/aws_neuron_venv_2.15/lib/python3.8/site-packages/torch_neuronx/xla_impl/trace.py(395): trace
test_t5.py(825): trace_decoder
test_t5.py(908): <module>
RuntimeError: Failed to execute the model status=2 message=Invalid

Found the issue: in the notebook, the runtime was not initialized for beam search, which led to the error above.

Next Steps

  • Add parallelism support from neuronx-distributed
  • Documentation
  • Possibly a wider range of tests and some refactoring, if I have the bandwidth.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@JingyaHuang marked this pull request as ready for review November 9, 2023 00:29
Review threads (resolved): optimum/exporters/neuron/model_configs.py; optimum/neuron/modeling_seq2seq.py (9 threads)
@dacorvo (Collaborator) commented Nov 14, 2023

Awesome! I have a few questions/issues regarding the generation code, though.

@michaelbenayoun (Member) left a comment

Left a few comments, and opened the discussion on the generate methods: would it be possible to reuse what we already have?

Review threads (resolved): optimum/exporters/neuron/base.py; optimum/exporters/neuron/config.py (4 threads); optimum/neuron/modeling_seq2seq.py (3 threads); tests/generation/test_hub.py (2 threads)
@dacorvo (Collaborator) commented Nov 24, 2023

I noticed that all the code for this new model class lives in modeling_seq2seq.py, which is fine. However, when I submitted NeuronModelForCausalLM, I was told to split it into modeling_decoder.py and modeling.py. I think this is inconsistent, and I will eventually push a pull request to have all causal LM classes in modeling_decoder.

@dacorvo (Collaborator) left a comment

Thanks for taking the time to merge the generation code. Most of my questions are about the part I did not review when it was first merged: I find it a bit cryptic where it diverges from the original transformers code I am more familiar with. My main concern is how this is actually compatible with the generation config / generation parameters that can be passed to generate. So, if possible, I would like to have some unit tests on that before formally approving.
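For illustration, a rough sketch of the kind of unit test being requested (hypothetical test and fixture, assuming a small t5-small export; not part of this PR):

import pytest
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForSeq2SeqLM


@pytest.fixture(scope="module")
def neuron_t5():
    # Hypothetical fixture: compile a small T5 once for the whole test module.
    input_shapes = {"batch_size": 1, "sequence_length": 64, "num_beams": 1}
    model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, **input_shapes)
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    return model, tokenizer


def test_generate_honors_max_new_tokens(neuron_t5):
    model, tokenizer = neuron_t5
    inputs = tokenizer("translate English to German: Hello.", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=5)
    # Expected behaviour if the generation parameters are honored:
    # decoder start token plus at most 5 newly generated tokens.
    assert outputs.shape[1] <= 6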

Review thread (resolved): optimum/neuron/generation/utils.py
(batch_size, num_padding_values), dtype=attention_mask.dtype, device=attention_mask.device
),
attention_mask,
torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device),
Collaborator: Why do you need this extra one? Is it for a new token that has been generated?

decoder_attention_mask = torch.cat(
    [
        torch.zeros((batch_size, num_padding_values), dtype=torch.int32),
        torch.ones((batch_size, 2), dtype=torch.int32),
Collaborator: Where do these two input ids come from?
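For context, a standalone illustration of the mask layout being discussed, with made-up sizes (not the PR's code):

import torch

batch_size, num_padding_values = 1, 3
decoder_attention_mask = torch.cat(
    [
        torch.zeros((batch_size, num_padding_values), dtype=torch.int32),  # num_padding_values masked positions
        torch.ones((batch_size, 2), dtype=torch.int32),  # the two unmasked positions the question is about
    ],
    dim=-1,
)
print(decoder_attention_mask)  # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)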


@staticmethod
def _initialize_past(past_key_values, num_padding_values):
"""Initialize past_key_values with zeros -- the structure depends on `batch_axis`"""
Collaborator: You actually keep the first two tokens of the past_key_values and pad left with num_padding_values zeros.

Collaborator (author): I am not using this function at all; I only moved it out of _update_model_kwargs_for_xla_generation so it would be easier to override.

new_past = ()
for past_layer in past_key_values:
    new_past_layer = list(past_layer)
    for i, _ in enumerate(new_past_layer[:2]):
Collaborator: That's basically for i in range(2), isn't it?

Collaborator (author): I think so, but I was not the one who contributed it, so I am not sure whether there is a specific intention here...
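As a standalone sanity check of the equivalence discussed above (illustrative values, not the PR's code):

past_layer = ("self_key", "self_value", "cross_key", "cross_value")  # stand-in for a per-layer tuple
new_past_layer = list(past_layer)

indices_a = [i for i, _ in enumerate(new_past_layer[:2])]
indices_b = list(range(2))
assert indices_a == indices_b == [0, 1]  # both loop forms visit exactly the first two entries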

Review thread (resolved): optimum/neuron/generation/utils.py
Comment on lines +790 to +798
if synced_gpus:
    # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
    # The following logic allows an early break if all peers finished generating their sequence
    this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
    # send 0.0 if we finished, 1.0 otherwise
    dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
    # did all peers finish? the reduced sum will be 0.0 then
    if this_peer_finished_flag.item() == 0.0:
        break
Collaborator: You keep this because you plan to use some kind of distributed execution?

Collaborator (author): No clue. I wanted to remove it at some point, but it seems it could possibly be usable on an XLA device...

Member: Yes, I am not sure that I understand the logic here, but it seems that the operations used should be compatible with our system.

Review threads (resolved): optimum/neuron/generation/utils.py (2 threads); tests/generation/test_generate.py
@michaelbenayoun (Member) left a comment

Not sure I understand everything, but overall it looks good to me. Left a few nits and comments.

Review threads (resolved): optimum/exporters/neuron/__main__.py (5 threads); optimum/neuron/generation/utils.py
Comment on lines +659 to +734
r"""
Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

<Tip warning={true}>

In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).

</Tip>


Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.

max_length (`int`, *optional*, defaults to 20):
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
seq_length (`Optional[int]`, defaults to `False`):
Length of current input_ids sequence
is_traced_inference (`bool`, defaults to `False`):
Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores
are computed inside the decoder.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

Return:
[`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.

Examples:

```python
>>> from transformers import AutoTokenizer
>>> from optimum.neuron import NeuronModelForSeq2SeqLM

>>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
>>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 1}
>>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes)

>>> input_prompt = "translate English to German: Lets eat good food."
>>> inputs = tokenizer(input_prompt, return_tensors="pt")

>>> outputs = model.greedy_search(inputs.input_ids)

>>> results = [tokenizer.decode(t, skip_special_tokens=True) for t in outputs]
```
"""
Member: I don't know if we should copy/paste the whole docstring from functions we override. I am personally not too much in favor of doing so, because it adds maintenance on our side for not so much added value. But it's just a thought.

Collaborator (author): I wanted to add the explanation for is_traced_inference and felt it would be better to have the docstring complete; especially since we will tailor quite a lot of these args to be compatible with Neuron in the near future, it is better to have a complete docstring explaining that.
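One possible middle ground for the docstring-maintenance concern (a sketch only, not what this PR does): copy the upstream docstring programmatically and append only the Neuron-specific additions, so most of the text stays maintained upstream.

from transformers.generation.utils import GenerationMixin


def inherit_docstring(parent_fn, extra: str):
    """Decorator: copy `parent_fn.__doc__` and append Neuron-specific notes."""
    def decorator(fn):
        fn.__doc__ = (parent_fn.__doc__ or "") + "\n\n" + extra
        return fn
    return decorator


# Hypothetical usage on the overridden method:
# @inherit_docstring(
#     GenerationMixin.greedy_search,
#     "is_traced_inference (`bool`, defaults to `False`): whether the decoder is traced.",
# )
# def greedy_search(self, input_ids, ...):
#     ...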

Review threads (resolved): optimum/neuron/modeling_base.py (2 threads)
@dacorvo (Collaborator) left a comment

Thank you very much for this awesome model! And thanks also for adding the tests; I think it was worth the extra time since you caught some issues.

@JingyaHuang merged commit aabcedb into main on Dec 4, 2023 (7 checks passed).
@JingyaHuang deleted the add-t5-export branch on December 4, 2023 at 20:46.