get_ppl missed the last token of each iteration during multi-iter prefill #2499

Merged: 14 commits, Sep 26, 2024
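A note on the bug being fixed: when a prompt longer than `max_input_len` is prefilled over several iterations, the old `get_ppl` shifted targets within each chunk and padded the chunk's final position, so the logit produced at the end of every iteration was never scored against the first token of the next chunk. A minimal sketch of the difference with toy token ids (illustrative only, not lmdeploy code):

```python
tokens = [10, 11, 12, 13, 14, 15]  # toy sequence, scored in chunks of 3
chunk = 3
PAD = -100  # ignore_index used by cross_entropy below

# Old behaviour: targets are shifted within each chunk and the chunk's
# last position is padded, so token 13 is never predicted.
old_targets = []
for i in range(0, len(tokens), chunk):
    ids = tokens[i:i + chunk]
    old_targets += (ids + [PAD])[1:]

# Fixed behaviour: a chunk's last position is scored against the first
# token of the next chunk; only the true final position is padded.
new_targets = []
for i in range(0, len(tokens), chunk):
    tgt = tokens[i + 1:i + 1 + chunk]
    new_targets += tgt + [PAD] * (min(chunk, len(tokens) - i) - len(tgt))

print(old_targets)  # [11, 12, -100, 14, 15, -100] -> 13 is missed
print(new_targets)  # [11, 12, 13, 14, 15, -100]   -> every target scored
```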
31 changes: 0 additions & 31 deletions lmdeploy/pytorch/engine/engine.py
@@ -1030,34 +1030,3 @@ async def async_end(self, session_id: int):
def end(self, session_id: int):
"""Add new session."""
return self.engine_instance.end(session_id)

def decode(self,
input_ids,
input_embeddings: List[InputEmbeddingType] = None,
input_embedding_ranges: List[InputEmbeddingRangeType] = None,
steps: List[int] = None,
sequence_start: bool = True,
sequence_end: bool = True,
adapter_names: List[str] = None):
"""Perform context decode on input tokens.

Args:
input_ids (List[List[int]] | List[np.ndarray]): the batch of input
token ids
steps (List[int]): the offset of the k/v cache
input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
embeddings features
input_embedding_ranges (List[List[Tuple[int, int]]]):
the begin/end offsets of input_embeddings to input_ids
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
"""
return self.engine_instance.decode(
input_ids,
input_embeddings=input_embeddings,
input_embedding_ranges=input_embedding_ranges,
steps=steps,
sequence_start=sequence_start,
sequence_end=sequence_end,
adapter_names=adapter_names)
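For reference, the removed wrapper simply forwarded to the engine instance; a hedged sketch of an equivalent direct call, using only the parameters documented above (the `engine_instance` object is assumed to exist, so this is illustrative rather than runnable on its own):

```python
# Hypothetical direct call replacing the removed wrapper (objects assumed).
logits = engine_instance.decode(
    input_ids=[[1, 42, 7, 99]],  # one sequence in the batch
    steps=[0],                   # k/v cache offset for each sequence
    sequence_start=True,         # this is the first chunk of the sequence
    sequence_end=True,           # and also the last chunk
)
```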
105 changes: 59 additions & 46 deletions lmdeploy/serve/utils.py
@@ -81,7 +81,12 @@ def get_logits(
for input_id in input_ids:
assert len(input_id) > 0

max_input_len = self.backend_config.max_prefill_token_num
bs = len(input_ids)
# TODO: find a better way to determine `max_input_len`; for now allocate
# at most 2 GiB for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)

n_max_iter = np.ceil(
max([len(input_id)
for input_id in input_ids]) / max_input_len).astype(int)
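To make the 2 GiB budget above concrete, a quick back-of-the-envelope sketch (the batch size and vocabulary size are illustrative assumptions):

```python
# At most 2 GiB of fp32 logits with shape [bs, max_input_len, vocab_size].
bs = 4                # assumed batch size
vocab_size = 32000    # assumed vocabulary size
bytes_per_value = 4   # float32

max_input_len = 2 * 1024**3 // (bs * vocab_size * bytes_per_value)
print(max_input_len)  # 4194 -> each prefill iteration handles ~4k tokens
```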
@@ -173,79 +178,87 @@ def _split_embeddings(input_ids, niter, iter_len, embeddings,
logits = torch.cat(logits, dim=1)
return logits

def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
"""Get perplexity scores given a list of input tokens.
def get_ppl(
self, input_ids: Union[List[int],
List[List[int]]]) -> Union[float, List[float]]:
"""Get perplexity scores given a list of input tokens that have to be
of the same length.

Args:
input_ids (Union[List[int], List[List[int]]]): the batch of
input token ids

Returns:
Union[float, List[float]]: A perplexity score for a single input,
or a list of scores for a batch of inputs.
"""
assert len(input_ids) > 0
assert isinstance(input_ids, List)
if isinstance(input_ids[0], int):
input_ids = [input_ids]
for input_id in input_ids:
assert len(input_id) > 1

max_input_len = self.backend_config.max_prefill_token_num
n_max_iter = np.ceil(
max([len(input_id)
for input_id in input_ids]) / max_input_len).astype(int)
generator = self.engine.create_instance()

index_range_starts = []
index_range_ends = []
for input_id in input_ids:
index_range_start = np.array(
[i * max_input_len for i in range(n_max_iter)])
index_range_end = index_range_start + max_input_len
index_range_start[index_range_start >= len(input_id)] = len(
input_id)
index_range_end[index_range_end >= len(input_id)] = len(input_id)
index_range_starts.append(index_range_start)
index_range_ends.append(index_range_end)
bs = len(input_ids)
max_seq_len = max([len(_) for _ in input_ids])

generator = self.engine.create_instance()
all_loss_matrix = []
all_target_mask = []
for i in range(n_max_iter):
steps = [start[i] for start in index_range_starts]
_input_ids = [
input_id[start[i]:end[i]] for input_id, start, end in zip(
input_ids, index_range_starts, index_range_ends)
# TODO: find a better way to determine `max_input_len`; for now allocate
# at most 2 GiB for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)

losses = []
target_counts = []
for i in range(0, max_seq_len, max_input_len):
token_ids = [
input_id[i:i + max_input_len] for input_id in input_ids
]
_logits = generator.decode(_input_ids,
steps=steps,
sequence_start=(i == 0),
sequence_end=(i == n_max_iter - 1))
_logits = _logits.float().cpu()
steps = [i] * bs
logits = generator.decode(
input_ids=token_ids,
steps=steps,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= max_seq_len))
bsz, seq_len, vocab_size = logits.shape
logits = logits.float()
padding_token_id = -100
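# -100 marks positions to skip; it is passed as `ignore_index` to
# cross_entropy below, so padded targets contribute no loss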
target_ids = [(x + [padding_token_id])[1:] for x in _input_ids]
# The logit at each position predicts the next token, so this chunk's
# targets are token_ids[1:] plus the first token of the next chunk,
# i.e. input_id[i+1 : i+1+max_input_len]
target_ids = [
input_id[i + 1:i + 1 + max_input_len] for input_id in input_ids
]
target_ids = [
target_ids[i] + [padding_token_id]
if len(target_ids[i]) < len(token_ids[i]) else target_ids[i]
for i in range(bsz)
]
target_ids = [
torch.LongTensor(_target_ids) for _target_ids in target_ids
]
target_ids = pad_sequence(target_ids,
batch_first=True,
padding_value=padding_token_id)
target_ids = target_ids.to(_logits.device)
target_ids = target_ids.to(logits.device)
target_mask = target_ids != padding_token_id
target_count = torch.sum(target_mask, dim=-1)

# compute cross entropy loss
bsz, seq_len, vocab_size = _logits.shape
flat_logits = _logits.contiguous().view(-1, vocab_size)
flat_logits = logits.contiguous().view(-1, vocab_size)
flat_target_ids = target_ids.contiguous().view(-1)
flat_loss_matrix = torch.nn.functional.cross_entropy(
flat_logits,
flat_target_ids,
reduction='none',
ignore_index=padding_token_id)
flat_loss_matrix = flat_loss_matrix.view(bsz, seq_len)
loss = flat_loss_matrix.sum(dim=-1).cpu().view(bsz, -1)
target_count = target_mask.sum(dim=-1).cpu().view(bsz, -1)
losses.append(loss)
target_counts.append(target_count)

all_loss_matrix.append(flat_loss_matrix.view(bsz, seq_len))
all_target_mask.append(target_mask)
target_count = torch.concatenate(target_counts, dim=-1).sum(dim=-1)
loss_sum = torch.concatenate(losses, dim=-1).sum(dim=-1)

all_loss_matrix = torch.cat(all_loss_matrix, dim=1)
all_target_mask = torch.cat(all_target_mask, dim=1)
target_count = torch.sum(all_target_mask, dim=-1)
loss_sum = torch.sum(all_loss_matrix * all_target_mask, dim=1)
loss_avg = loss_sum / target_count
loss_avg = loss_avg.cpu().numpy()
loss_avg = loss_avg.numpy()

return loss_avg
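Worth noting for callers: as the code shows, the returned value is the mean token-level cross-entropy (`loss_avg`), not the exponentiated perplexity. A hedged usage sketch (`pipe` stands in for a constructed lmdeploy pipeline, so this is illustrative rather than standalone-runnable):

```python
import numpy as np

input_ids = [[1, 3087, 310, 278, 1234], [1, 15043, 3186, 29991, 2]]
loss_avg = pipe.get_ppl(input_ids)  # mean cross-entropy per target token
ppl = np.exp(loss_avg)              # exponentiate for conventional perplexity
```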
4 changes: 4 additions & 0 deletions lmdeploy/turbomind/deploy/config.py
@@ -135,5 +135,9 @@ def weight_type(self):
def group_size(self):
return self.model_config.group_size

@property
def vocab_size(self):
return self.model_config.vocab_size

def __str__(self):
return json.dumps(self.to_dict(), indent=2)
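The new property mirrors the existing `weight_type`/`group_size` delegation to the nested model config. A self-contained sketch of the pattern (field names assumed from the diff, simplified dataclasses in place of the real ones):

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    vocab_size: int = 32000

@dataclass
class TurbomindModelConfig:
    model_config: ModelConfig

    @property
    def vocab_size(self):
        # surface the nested value at the top level, as the diff does
        return self.model_config.vocab_size

assert TurbomindModelConfig(ModelConfig()).vocab_size == 32000
```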
92 changes: 7 additions & 85 deletions lmdeploy/turbomind/turbomind.py
@@ -8,7 +8,7 @@
from dataclasses import asdict
from itertools import repeat
from queue import LifoQueue, Queue
from typing import Dict, Iterable, List, Union
from typing import Dict, Iterable, List

import numpy as np
import torch
@@ -314,7 +314,7 @@ def create_instance(self, cuda_stream_id=0):
Returns:
TurboMindInstance: an instance of turbomind
"""
return TurboMindInstance(self, cuda_stream_id)
return TurboMindInstance(self, self.config, cuda_stream_id)


class TurboMindInstance:
@@ -325,7 +325,10 @@ class TurboMindInstance:
cuda_stream_id(int): identity of a cuda stream
"""

def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0):
def __init__(self,
tm_model: TurboMind,
config: TurbomindModelConfig,
cuda_stream_id: int = 0):
self.tm_model = tm_model
self.cuda_stream_id = cuda_stream_id

@@ -343,6 +346,7 @@ def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0):
self.que = Queue()
self.executor: ThreadPoolExecutor = None
self.future = None
self.config = config

def _create_model_instance(self, device_id):
rank = self.node_id * self.gpu_count + device_id
@@ -922,85 +926,3 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
logits = outputs['logits']

return logits[:, :-1, :]

def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
"""Get perplexity scores given a list of input tokens.

Args:
input_ids (Union[List[int], List[List[int]]]): the batch of input token ids
""" # noqa 501

if len(input_ids) == 0:
input_ids = [[]]
if isinstance(input_ids[0], int):
input_ids = [input_ids]

max_input_len = 16 * 1024
# max_input_len = 16
n_max_iter = np.ceil(
max([len(input_id)
for input_id in input_ids]) / max_input_len).astype(int)

device = 'cpu' if n_max_iter > 1 else 'cuda'
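# (each iteration's logits are moved to this device before concatenation;
# CPU avoids holding the full [bsz, seq_len, vocab_size] tensor on GPU)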

index_range_starts = []
index_range_ends = []
for input_id in input_ids:
index_range_start = np.array(
[i * max_input_len for i in range(n_max_iter)])
index_range_end = index_range_start + max_input_len
index_range_start[index_range_start >= len(input_id)] = len(
input_id)
index_range_end[index_range_end >= len(input_id)] = len(input_id)
index_range_starts.append(index_range_start)
index_range_ends.append(index_range_end)

logits = []
for i in range(n_max_iter):
steps = [start[i] for start in index_range_starts]
_input_ids = [
input_id[start[i]:end[i]] for input_id, start, end in zip(
input_ids, index_range_starts, index_range_ends)
]
_logits = self.decode(_input_ids,
steps,
sequence_start=(i == 0),
sequence_end=(i == n_max_iter - 1))
if _logits is None:
return None
_logits = _logits.to(device=device)
logits.append(_logits)

# concat logits. Shape is [bsz, seq_len, vocab_size]
logits = torch.cat(logits, dim=1)

# get target ids
padding_token_id = -100
target_ids = [(_input_ids + [padding_token_id])[1:]
for _input_ids in input_ids]
target_ids = [
torch.Tensor(torch.LongTensor(_target_ids))
for _target_ids in target_ids
]
target_ids = pad_sequence(target_ids,
batch_first=True,
padding_value=padding_token_id)
target_ids = target_ids.to(logits.device)
target_mask = target_ids != padding_token_id
target_count = torch.sum(target_mask, dim=-1)

# compute cross entropy loss
bsz, seq_len, vocab_size = logits.shape
flat_logits = logits.contiguous().view(-1, vocab_size)
flat_target_ids = target_ids.contiguous().view(-1)
flat_loss_matrix = torch.nn.functional.cross_entropy(
flat_logits,
flat_target_ids,
reduction='none',
ignore_index=padding_token_id)

loss_matrix = flat_loss_matrix.view(bsz, seq_len)
loss_sum = torch.sum(loss_matrix * target_mask, dim=1)
loss_avg = loss_sum / target_count
loss_avg = loss_avg.cpu().numpy()
return loss_avg