diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 70f16ab23..b930d54c4 100644
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -2,7 +2,7 @@
 import time
 
 import torch
 
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, set_seed
 
 from optimum.neuron import NeuronModelForCausalLM
@@ -35,7 +35,6 @@ def generate(model, tokenizer, prompts, length, temperature):
         sample_output = model.generate(
             **tokens,
             do_sample=True,
-            min_length=length,
             max_length=length,
             temperature=temperature,
         )
@@ -68,7 +67,10 @@ def generate(model, tokenizer, prompts, length, temperature):
         "--save_dir", type=str, help="The save directory. Allows to avoid recompiling the model every time."
     )
    parser.add_argument("--compare", action="store_true", help="Compare with the genuine transformers model on CPU.")
+    parser.add_argument("--seed", type=int, default=None, help="Pass a seed for reproducibility.")
     args = parser.parse_args()
+    if args.seed is not None:
+        set_seed(args.seed)
     prompts = args.prompts.split("|")
     batch_size = len(prompts)
     model = load_llm_optimum(args.model, batch_size, args.num_cores, args.auto_cast_type)
diff --git a/optimum/neuron/modeling.py b/optimum/neuron/modeling.py
index 3a6fdba0e..1a2bffa23 100644
--- a/optimum/neuron/modeling.py
+++ b/optimum/neuron/modeling.py
@@ -14,9 +14,9 @@
 # limitations under the License.
 """NeuronModelForXXX classes for inference on neuron devices using the same API as Transformers."""
 
+import copy
 import logging
-import warnings
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, Optional, Union
 
 import torch
 from transformers import (
@@ -33,10 +33,10 @@
     GenerationMixin,
     LogitsProcessorList,
     LogitsWarper,
-    SampleDecoderOnlyOutput,
     StoppingCriteriaList,
     TopKLogitsWarper,
 )
+from transformers.generation.utils import GenerationMode
 from transformers.modeling_outputs import (
     BaseModelOutputWithPooling,
     MaskedLMOutput,
@@ -55,7 +55,7 @@
     from pathlib import Path
     from tempfile import TemporaryDirectory
 
-    from transformers import BaseStreamer, GenerationConfig, PretrainedConfig
+    from transformers import GenerationConfig, PretrainedConfig
 
 
 logger = logging.getLogger(__name__)
@@ -606,16 +606,8 @@ def forward(
         input_ids: torch.Tensor,
         cache_ids: torch.Tensor,
         start_ids: torch.Tensor = None,
-        output_hidden_states: bool = False,
-        output_attentions: bool = False,
-        attention_mask: torch.Tensor = None,
         return_dict: bool = True,
     ):
-        if output_hidden_states or output_attentions or attention_mask is not None:
-            warnings.warn(
-                "Warning: These arguments are not used by forward(): \
-                (output_hidden_states, output_attentions, attention_mask)"
-            )
         # Evaluate the output logits, storing the current key and values at the indices specified by cache_ids
         out_logits = self.model.forward(input_ids, cache_ids, start_ids)
         out_logits = out_logits[:, None, :]
@@ -624,25 +616,11 @@
             return ModelOutput([("logits", out_logits)])
         return (out_logits,)
 
-    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
-        # Sanity checks
-        if kwargs.get("past_key_values", None) is not None:
-            raise ValueError("This model does not support dynamic key, value cache.")
-        batch_size, sequence_length = input_ids.shape
-        if batch_size != self.batch_size:
-            raise ValueError(
-                f"The specified batch_size ({batch_size}) does not match the model static batch size ({self.batch_size})"
-            )
-        if sequence_length > self.max_length:
-            raise ValueError(
-                f"The current sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})"
-            )
+    def prepare_inputs_for_generation(
+        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
+    ) -> Dict[str, torch.Tensor]:
         # convert attention_mask to start_ids
-        attention_mask = None
         start_ids = None
-        if "attention_mask" in kwargs:
-            attention_mask = kwargs["attention_mask"]
-
         if attention_mask is not None:
             _, start_ids = attention_mask.max(axis=1)
 
@@ -669,6 +647,130 @@ def can_generate(self) -> bool:
         """Returns True to validate the check made in `GenerationMixin.generate()`."""
         return True
 
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        generation_config: Optional["GenerationConfig"] = None,
+        **kwargs,
+    ) -> torch.LongTensor:
+        r"""
+        A streamlined generate() method overriding the transformers.GenerationMixin.generate() method.
+
+        This method uses the same logits processors/warpers and stopping criteria as the transformers library
+        `generate()` method but restricts the generation to greedy search and sampling.
+
+        It does not support the advanced options of the transformers `generate()` method.
+
+        Please refer to https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
+        for details on generation configuration.
+
+        Parameters:
+            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices.
+            generation_config (`~transformers.generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~transformers.generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+
+        Returns:
+            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens.
+ """ + # The actual generation configuration is a combination of config and parameters + generation_config = copy.deepcopy(self.generation_config if generation_config is None else generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + # Check model kwargs are actually used by either prepare_inputs_for_generation or forward + self._validate_model_kwargs(model_kwargs) + generation_config.validate() + + unsupported_generation_flags = [ + "output_attentions", + "output_hidden_states", + "output_scores", + "return_dict_in_generate", + ] + for flag in unsupported_generation_flags: + if getattr(generation_config, flag, False): + raise ValueError("{flag} is not supported for generation.") + + # Verify that the inputs are compatible with the model static input dimensions + batch_size, sequence_length = input_ids.shape + if batch_size != self.batch_size: + raise ValueError( + f"The specified batch_size ({batch_size}) does not match the model static batch size ({self.batch_size})" + ) + if sequence_length > self.max_length: + raise ValueError( + f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})" + ) + min_length = generation_config.min_length + if min_length > self.max_length: + raise ValueError( + f"The minimum generation length ({min_length}) exceeds the model static sequence length ({self.max_length})" + ) + max_length = generation_config.max_length + if min_length > self.max_length: + logger.warning( + f"The maximum generation length ({max_length}) exceeds the model static sequence length ({self.max_length})" + ) + + # Instantiate transformers library processors and criterias + logits_processor = self._get_logits_processor( + generation_config, + input_ids_seq_length=input_ids.shape[-1], + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=None, + logits_processor=LogitsProcessorList(), + ) + stopping_criteria = self._get_stopping_criteria(generation_config, stopping_criteria=StoppingCriteriaList()) + + # Special tokens are required for generation + eos_token_id = generation_config.eos_token_id + # This is not supposed to happen for any of the models we support + assert eos_token_id is not None and not isinstance(eos_token_id, list) + if generation_config.pad_token_id is None: + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + generation_config.pad_token_id = eos_token_id + + # Drop the current generation context and clear the Key/Value cache + self.reset_generation() + + generation_mode = self._get_generation_mode(generation_config, None) + if generation_mode == GenerationMode.GREEDY_SEARCH: + return self.greedy_search( + input_ids, + attention_mask=attention_mask, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + **model_kwargs, + ) + elif generation_mode == GenerationMode.SAMPLE: + logits_warper = self._get_logits_warper(generation_config) + last_warper = logits_warper[-1] + fast_topk = isinstance(last_warper, TopKLogitsWarper) + if fast_topk: + # Replace the last warping operation by a faster alternative + logits_warper[-1] = self.FastTopKLogitsWarper(last_warper.top_k, last_warper.filter_value) + return self.sample( + input_ids, + attention_mask=attention_mask, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + logits_warper=logits_warper, + 
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                **model_kwargs,
+            )
+        else:
+            raise ValueError("Unsupported generation mode")
+
     class FastTopKLogitsWarper(LogitsWarper):
         r"""Returns [batch_size, top_k] scores and indices instead of [batch_size, vocab_size] scores."""
 
@@ -681,24 +783,17 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
             # Remove all tokens with a probability less than the last token of the top-k
             return torch.topk(scores, top_k)
 
-    # Adapted from transformers.generation.utils.GenerationMixin.sample
     def sample(
         self,
         input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_warper: Optional[LogitsProcessorList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_scores: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        synced_gpus: bool = False,
-        streamer: Optional["BaseStreamer"] = None,
+        eos_token_id: int,
+        pad_token_id: int,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        logits_warper: LogitsProcessorList,
+        attention_mask: Optional[torch.Tensor] = None,
         **model_kwargs,
-    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
+    ) -> torch.LongTensor:
         r"""
         This is a simplified version of the transformers `GenerationMixin.sample()` method that is optimized for neuron inference.
@@ -710,193 +805,202 @@ def sample(
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
-            logits_processor (`LogitsProcessorList`, *optional*):
+            eos_token_id (`int`):
+                The id of the *end-of-sequence* token.
+            pad_token_id (`int`):
+                The id of the *padding* token.
+            logits_processor (`LogitsProcessorList`):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`, *optional*):
+            stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`, *optional*):
+            logits_warper (`LogitsProcessorList`):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                 to warp the prediction score distribution of the language modeling head applied before multinomial
                 sampling at each generation step.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            eos_token_id (`Union[int, List[int]]`, *optional*):
-                The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
-            output_attentions (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more details.
-            output_hidden_states (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more details.
-            output_scores (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
-            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-            synced_gpus (`bool`, *optional*, defaults to `False`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            streamer (`BaseStreamer`, *optional*):
-                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
-                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices.
             model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
-                an encoder-decoder model the kwargs should include `encoder_outputs`.
+                Additional model specific kwargs will be forwarded to the `forward` function of the model.
 
         Return:
-            [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
-            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens.
 
         """
-        # We don't support all parameters
-        if synced_gpus:
-            raise ValueError("Neuron models cannot run on GPUs.")
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
-                UserWarning,
-            )
-            stopping_criteria = self.validate_stopping_criteria(stopping_criteria, max_length)
-        # Modifications to the original algorithm are all conditioned by the fast_topk boolean
-        fast_topk = False
-        if logits_warper is None:
-            logits_warper = LogitsProcessorList()
-        else:
-            last_warper = logits_warper[-1]
-            fast_topk = isinstance(last_warper, TopKLogitsWarper)
-            if fast_topk:
-                # Replace the last warping operation by a faster alternative
-                logits_warper[-1] = self.FastTopKLogitsWarper(last_warper.top_k, last_warper.filter_value)
-
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-        output_attentions = (
-            output_attentions if output_attentions is not None else self.generation_config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+        return self.generate_tokens(
+            input_ids,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            do_sample=True,
+            logits_warper=logits_warper,
+            attention_mask=attention_mask,
+            **model_kwargs,
         )
-        return_dict_in_generate = (
-            return_dict_in_generate
-            if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
+
+    def greedy_search(
+        self,
+        input_ids: torch.LongTensor,
+        eos_token_id: int,
+        pad_token_id: int,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        attention_mask: Optional[torch.Tensor] = None,
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        r"""
+        This is a simplified version of the transformers `GenerationMixin.greedy_search()` method that is optimized for neuron inference.
+
+        Please refer to https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.greedy_search.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            eos_token_id (`int`):
+                The id of the *end-of-sequence* token.
+            pad_token_id (`int`):
+                The id of the *padding* token.
+            logits_processor (`LogitsProcessorList`):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+                used to modify the prediction scores of the language modeling head applied at each generation step.
+            stopping_criteria (`StoppingCriteriaList`):
+                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+                used to tell if the generation loop should stop.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices.
+            model_kwargs:
+                Additional model specific kwargs will be forwarded to the `forward` function of the model.
+
+        Return:
+            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens.
+
+        """
+        return self.generate_tokens(
+            input_ids,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
+            attention_mask=attention_mask,
+            **model_kwargs,
         )
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
 
+    def generate_tokens(
+        self,
+        input_ids: torch.LongTensor,
+        eos_token_id: int,
+        pad_token_id: int,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        do_sample: bool,
+        logits_warper: Optional[LogitsProcessorList] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        r"""
+        Generate tokens using sampling or greedy search.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            eos_token_id (`int`):
+                The id of the *end-of-sequence* token.
+            pad_token_id (`int`):
+                The id of the *padding* token.
+            logits_processor (`LogitsProcessorList`):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+                used to modify the prediction scores of the language modeling head applied at each generation step.
+            stopping_criteria (`StoppingCriteriaList`):
+                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+                used to tell if the generation loop should stop.
+            do_sample (`bool`): Whether to sample new tokens or simply take the one with the highest score.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices.
+            model_kwargs:
+                Additional model specific kwargs will be forwarded to the `forward` function of the model.
+
+        Return:
+            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens.
+
+        """
+        if do_sample:
+            warper_indices = len(logits_warper) > 0 and isinstance(logits_warper[-1], self.FastTopKLogitsWarper)
         # keep track of which sequences are already finished
         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
 
-        # This is specific to Neuron models
-        self.reset_generation()
-
         # auto-regressive generation
         while True:
             # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, attention_mask, **model_kwargs)
 
             # forward pass to get next token
             outputs = self(
                 **model_inputs,
                 return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
 
             next_token_logits = outputs.logits[:, -1, :]
 
             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            if fast_topk:
-                # Get [batch_size, top_k] scores and indices instead of [batch_size, vocab_size] scores
-                next_token_scores, next_token_indices = logits_warper(input_ids, next_token_scores)
+
+            if do_sample:
+                next_tokens = self.sample_token(
+                    next_token_scores,
+                    logits_warper=logits_warper,
+                    warper_indices=warper_indices,
+                )
             else:
-                next_token_scores = logits_warper(input_ids, next_token_scores)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    if fast_topk:
-                        # Expand the [batch_size, top_k] scores to [batch_size, vocab_size] scores
-                        expanded_scores = torch.full(next_token_logits.shape, last_warper.filter_value)
-                        expanded_scores.scatter(1, next_token_indices, next_token_scores)
-                        scores += (expanded_scores,)
-                    else:
-                        scores += (next_token_scores,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # sample
-            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1)
-            if fast_topk:
-                # Convert the topk relative tokens to actual vocabulary tokens
-                next_tokens = torch.gather(next_token_indices, 1, next_tokens)
-            next_tokens = next_tokens.squeeze(1)
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
 
             # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
`pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - # update generated ids, model inputs, and length for next step + # update inputs for the next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) + unfinished_sequences = unfinished_sequences * next_tokens.ne(eos_token_id) - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - break + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + break # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores): + if stopping_criteria(input_ids, None): break - if streamer is not None: - streamer.end() + return input_ids - if return_dict_in_generate: - return SampleDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) + def sample_token( + self, + scores, + logits_warper: LogitsProcessorList, + warper_indices: bool, + ) -> torch.LongTensor: + if warper_indices: + # Get [batch_size, top_k] scores and indices instead of [batch_size, vocab_size] scores + scores, next_token_indices = logits_warper(None, scores) else: - return input_ids + scores = logits_warper(None, scores) + + # sample + probs = torch.nn.functional.softmax(scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1) + if warper_indices: + # Convert the topk relative tokens to actual vocabulary tokens + next_tokens = torch.gather(next_token_indices, 1, next_tokens) + return next_tokens.squeeze(1) diff --git a/tests/inference/test_modeling_decoder.py b/tests/inference/test_modeling_decoder.py index 79ece3346..21a98342f 100644 --- a/tests/inference/test_modeling_decoder.py +++ b/tests/inference/test_modeling_decoder.py @@ -107,34 +107,44 @@ def test_model_from_hub(): _check_neuron_model(model) -def _test_model_generation(model, tokenizer, batch_size, length, **gen_kwargs): - prompt_text = "Hello, I'm a language model," - prompts = [prompt_text for _ in range(batch_size)] - tokens = tokenizer(prompts, return_tensors="pt") +def _test_model_generation(model, tokenizer, batch_size, input_length, **gen_kwargs): + input_ids = torch.ones((batch_size, input_length), dtype=torch.int64) with torch.inference_mode(): - sample_output = model.generate(**tokens, min_length=length, max_length=length, **gen_kwargs) + sample_output = model.generate(input_ids, **gen_kwargs) assert sample_output.shape[0] == batch_size - assert sample_output.shape[1] == length @pytest.mark.parametrize( - "gen_kwargs", [{"do_sample": True}, {"do_sample": True, "temperature": 0.7}], ids=["sample", "sample-with-temp"] + "gen_kwargs", + [ + {"do_sample": True}, + {"do_sample": True, "temperature": 0.7}, + {"do_sample": False}, + {"do_sample": False, "repetition_penalty": 1.2}, + ], + 
ids=["sample", "sample-with-temp", "greedy", "greedy_no-repeat"], ) @is_inferentia_test @requires_neuronx def test_model_generation(neuron_model_path, gen_kwargs): model = NeuronModelForCausalLM.from_pretrained(neuron_model_path) tokenizer = AutoTokenizer.from_pretrained(neuron_model_path) - # Using static model parameters - _test_model_generation(model, tokenizer, model.batch_size, model.max_length, **gen_kwargs) - # Using a lower max length - _test_model_generation(model, tokenizer, model.batch_size, model.max_length // 2, **gen_kwargs) + _test_model_generation(model, tokenizer, model.batch_size, 10, **gen_kwargs) + + +@is_inferentia_test +@requires_neuronx +def test_model_generation_input_dimensions(neuron_model_path): + model = NeuronModelForCausalLM.from_pretrained(neuron_model_path) + tokenizer = AutoTokenizer.from_pretrained(neuron_model_path) + # Using valid input dimensions + _test_model_generation(model, tokenizer, model.batch_size, model.max_length // 2) # Using an incompatible batch_size with pytest.raises(ValueError, match="The specified batch_size"): - _test_model_generation(model, tokenizer, model.batch_size + 1, model.max_length, **gen_kwargs) - # Using an incompatible generation length - with pytest.raises(ValueError, match="The current sequence length"): - _test_model_generation(model, tokenizer, model.batch_size, model.max_length * 2, **gen_kwargs) + _test_model_generation(model, tokenizer, model.batch_size + 1, model.max_length) + # Using an incompatible input length + with pytest.raises(ValueError, match="The input sequence length"): + _test_model_generation(model, tokenizer, model.batch_size, input_length=model.max_length * 2) @is_inferentia_test