[tnx] fix optimum token selection and sampling #2233

Merged · 1 commit · Jul 25, 2024
Changes from all commits
@@ -25,7 +25,7 @@
from transformers import PretrainedConfig
from transformers_neuronx import bucket
from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB
from optimum.neuron.generation import TokenSelector
from djl_python.transformers_neuronx_scheduler.optimum_token_selector import OptimumTokenSelector
from optimum.neuron.utils.version_utils import check_compiler_compatibility, get_neuronxcc_version
from optimum.modeling_base import OptimizedModel
from transformers.generation import StoppingCriteriaList
@@ -238,11 +238,12 @@ def generate(
self._validate_model_kwargs(model_kwargs)

# Instantiate a TokenSelector for the specified configuration
selector = TokenSelector.create(input_ids,
generation_config,
self,
self.max_length,
stopping_criteria=stopping_criteria)
selector = OptimumTokenSelector.create(
input_ids,
generation_config,
self,
self.max_length,
stopping_criteria=stopping_criteria)

# Verify that the inputs are compatible with the model static input dimensions
batch_size, sequence_length = input_ids.shape
@@ -280,7 +281,7 @@ def generate(
def generate_tokens(
self,
input_ids: torch.LongTensor,
selector: TokenSelector,
selector: OptimumTokenSelector,
batch_size: int,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs,
@@ -291,7 +292,7 @@ def generate_tokens(
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
selector (`TokenSelector`):
selector (`OptimumTokenSelector`):
The object implementing the generation logic based on transformers processors and stopping criteria.
batch_size (`int`):
The actual input batch size. Used to avoid generating tokens for padded inputs.
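The change in this file only swaps the selector implementation: `generate()` now builds an `OptimumTokenSelector` and `generate_tokens()` consumes it through the same `select()` contract as before. As a rough, hypothetical sketch of that contract (this is not the code in this file, which additionally handles attention masks, KV caching and static batch padding), a decode loop driven by the selector could look like the following, where `model_forward` is an assumed callable returning next-token logits of shape `(batch_size, vocab_size)`:

import torch

def decode_loop(model_forward, selector, input_ids, max_new_tokens):
    # Track which sequences are still generating.
    unfinished = torch.ones(input_ids.shape[0], dtype=torch.bool)
    for _ in range(max_new_tokens):
        logits = model_forward(input_ids)  # (batch_size, vocab_size), assumed
        next_tokens = selector.select(input_ids, logits)
        # Finished sequences keep emitting the pad token.
        next_tokens = torch.where(
            unfinished, next_tokens,
            torch.full_like(next_tokens, selector.pad_token_id))
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        unfinished &= ~torch.isin(next_tokens,
                                  torch.tensor(selector.eos_token_ids))
        if not unfinished.any():
            break
    return input_ids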
@@ -24,7 +24,6 @@
from dataclasses import dataclass

from djl_python.transformers_neuronx_scheduler.slot import Slot
from djl_python.rolling_batch.rolling_batch import filter_unused_generation_params
from djl_python.request import Request
from djl_python.transformers_neuronx_scheduler.token_selector import TokenSelector
from djl_python.transformers_neuronx_scheduler.speculation import (
@@ -0,0 +1,222 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
# The code below is heavily inspired by the Optimum Neuron token selector:
# https://github.com/huggingface/optimum-neuron/blob/main/optimum/neuron/generation/token_selector.py

import copy
import logging
from typing import TYPE_CHECKING, List, Optional

import torch
from transformers.generation import (
GenerationConfig,
GenerationMixin,
LogitsProcessorList,
StoppingCriteriaList,
)
from transformers.generation.utils import GenerationMode

from optimum.neuron.generation import FusedLogitsWarper

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)


# TODO: This is a temporary solution to avoid Optimum's dependency on transformers<4.42.
class OptimumTokenSelector:
"""Implements the token selection logic corresponding to a generation configuration.

This class combines and uses the logits processors and stopping criteria implemented in
the transformers library.

The algorithm to select these objects is heavily inspired by the transformers `GenerationMixin.generate()`
method, but the actual token selection methods are specific to this class.

This class does not inherit from `GenerationMixin` because it does not
include the code that produces the token logits.
Separating the production of the token logits from the token selection allows this class
to be used with different generation paradigms, either synchronously using a single `OptimumTokenSelector` in
`GenerationMixin.generate()` or asynchronously using multiple `OptimumTokenSelector` instances inside an inference endpoint.

The constructor of this class should not be called directly: instances should be obtained by
calling `OptimumTokenSelector.create()`.
"""

def __init__(
self,
mode: GenerationMode,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
eos_token_ids: List[int],
pad_token_id: int,
logits_warper: Optional[LogitsProcessorList] = None,
seed: Optional[int] = 0,
):
self.mode = mode
self.logits_processor = logits_processor
self.stopping_criteria = stopping_criteria
self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
self.logits_warper = logits_warper
self.generator = torch.Generator()
self.generator.manual_seed(seed)

@classmethod
def create(
cls,
input_ids: torch.Tensor,
generation_config: GenerationConfig,
model: GenerationMixin,
max_seq_length: int,
stopping_criteria: Optional[StoppingCriteriaList] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None,
seed: Optional[int] = 0,
) -> "OptimumTokenSelector":
r"""Creates the `TokenSelector` for a specific generation configuration.

Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
generation_config (`~transformers.generation.GenerationConfig`, *optional*):
The generation configuration to parametrize the token selection.
model (`~transformers.generation.GenerationMixin`):
The model providing the internal helpers used to select the logits processors and stopping criteria.
max_seq_length (`int`):
The maximum number of input + generated tokens for this model. It depends on the model compilation parameters.
stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList]`, defaults to `None`):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config.
tokenizer (`Optional[transformers.PreTrainedTokenizer]`, defaults to `None`):
A tokenizer used when stop strings are passed to generate.
seed (`Optional[int]`):
The optional seed for sampling. Defaults to zero.
Return:
`OptimumTokenSelector`: The token selector configured for the specified generation parameters.
"""
generation_config.validate()
generation_config = copy.deepcopy(generation_config)

unsupported_generation_flags = [
"output_attentions",
"output_hidden_states",
"output_scores",
"return_dict_in_generate",
]
for flag in unsupported_generation_flags:
if getattr(generation_config, flag, False):
raise ValueError("{flag} is not supported for generation.")

if generation_config.max_new_tokens is not None:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids.shape[
-1]

min_length = generation_config.min_length
if min_length > max_seq_length:
raise ValueError(
f"The minimum generation length ({min_length}) exceeds the model maximum sequence length ({max_seq_length})"
)
max_length = generation_config.max_length
if max_length > max_seq_length:
logger.warning(
f"Adjusting the maximum generation length ({max_length}) to the model maximum sequence length ({max_seq_length})"
)
generation_config.max_length = max_seq_length

# This is not supposed to happen for any of the models we support
eos_token_id = generation_config.eos_token_id
assert eos_token_id is not None
# The generation requires special tokens
eos_token_ids = eos_token_id if isinstance(eos_token_id,
list) else [eos_token_id]
generation_config._eos_token_tensor = torch.tensor(
eos_token_ids, device=input_ids.device)
if generation_config.pad_token_id is None:
logger.warning(
f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-ended generation."
)
generation_config.pad_token_id = eos_token_ids[0]

# Instantiate transformers library processors and criteria
logits_processor = model._get_logits_processor(
generation_config,
input_ids_seq_length=input_ids.shape[-1],
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=LogitsProcessorList(),
)
if stopping_criteria is None:
stopping_criteria = StoppingCriteriaList()
stopping_criteria = model._get_stopping_criteria(
generation_config,
stopping_criteria=stopping_criteria,
tokenizer=tokenizer)

generation_mode = generation_config.get_generation_mode()
if generation_mode not in [
GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE
]:
raise ValueError("Unsupported generation mode")

logits_warper = None
if generation_mode == GenerationMode.SAMPLE:
logits_warper = FusedLogitsWarper.from_config(generation_config)

return cls(
mode=generation_mode,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
logits_warper=logits_warper,
eos_token_ids=eos_token_ids,
pad_token_id=generation_config.pad_token_id,
seed=seed,
)

def select(self, input_ids: torch.LongTensor,
logits: torch.Tensor) -> torch.LongTensor:
"""Select the next tokens from the candidate logits.

Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation (not used in all generation modes).
logits (`torch.Tensor` of shape `(batch_size, vocab_size)`):
The logits corresponding to the generated tokens.

Return:
`torch.LongTensor`: A `torch.LongTensor` containing the selected tokens.
"""
scores = self.logits_processor(input_ids, logits)
if self.mode == GenerationMode.SAMPLE:
return self._sample(scores)
else:
return torch.argmax(scores, dim=-1)

def _sample(self, scores: torch.Tensor) -> torch.LongTensor:
# Get [batch_size, kept] scores and indices instead of [batch_size, vocab_size] scores
scores, next_token_indices = self.logits_warper(scores)

# sample
probs = torch.nn.functional.softmax(scores, dim=-1)
next_tokens = torch.multinomial(probs,
num_samples=1,
generator=self.generator)
# Convert the filtered tokens to actual vocabulary tokens
next_tokens = torch.gather(next_token_indices, 1, next_tokens)
return next_tokens.squeeze(1)
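The `_sample()` path above is where the sampling part of this fix lives: the fused warper reduces the scores to a `(batch_size, kept)` tensor, a token is drawn with the selector's own `torch.Generator`, and the drawn index is then mapped back to a real vocabulary id with `torch.gather`. A minimal, self-contained toy illustration of that last mapping step (plain torch with `topk` standing in for `FusedLogitsWarper`; the numbers are made up):

import torch

torch.manual_seed(0)
logits = torch.tensor([[0.1, 2.0, 0.3, 1.5, 0.2]])  # (batch=1, vocab=5)
# Keep only the top-2 candidates, as a top_k warper would.
scores, next_token_indices = torch.topk(logits, k=2, dim=-1)
probs = torch.nn.functional.softmax(scores, dim=-1)  # softmax over the kept scores only
sampled = torch.multinomial(probs, num_samples=1)  # index into the kept candidates (0 or 1)
# Map the filtered index back to a vocabulary id (here token 1 or token 3).
next_tokens = torch.gather(next_token_indices, 1, sampled).squeeze(1)
print(next_tokens)  # tensor([1]) or tensor([3])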
@@ -169,7 +169,8 @@ def create(

logits_warper = None
if generation_mode == GenerationMode.SAMPLE:
logits_warper = model._get_logits_warper(generation_config)
logits_warper = model._get_logits_warper(generation_config,
device=model.device)
if len(logits_warper) == 0:
generation_mode = GenerationMode.GREEDY_SEARCH

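The change above tracks a signature change in recent transformers releases (around 4.42), where the private `_get_logits_warper` helper takes an additional `device` argument; when the returned warper list is empty (for example, `do_sample=True` with default temperature, top-k and top-p), the selector falls back to greedy search as shown. A hypothetical compatibility shim, not part of this PR, that tolerates both signatures of the private helper could look like this:

import inspect

def build_logits_warper(model, generation_config):
    # The helper is private and its signature differs across transformers
    # versions, so inspect it before calling; treat this as a sketch only.
    params = inspect.signature(model._get_logits_warper).parameters
    if "device" in params:
        return model._get_logits_warper(generation_config, device=model.device)
    return model._get_logits_warper(generation_config)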