The get_ppl missed the last token of each iteration during multi-iter prefill #2499

Merged · 14 commits · Sep 26, 2024
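
In short: when a long prompt is prefilled over several iterations, the targets for the chunk [i, i+N) are the tokens [i+1, i+N+1), so the last target of each iteration comes from the first token of the next chunk; shifting targets only inside the chunk silently drops that token. A minimal illustrative sketch of the intended target construction (toy example, not code from this diff):

tokens = list(range(10))   # toy token ids
chunk = 4                  # prefill chunk size per iteration
for i in range(0, len(tokens), chunk):
    inputs = tokens[i:i + chunk]
    # the target slice crosses the chunk boundary by one token
    targets = tokens[i + 1:i + 1 + chunk]
    print(inputs, targets)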
31 changes: 0 additions & 31 deletions lmdeploy/pytorch/engine/engine.py
@@ -1030,34 +1030,3 @@ async def async_end(self, session_id: int):
def end(self, session_id: int):
"""Add new session."""
return self.engine_instance.end(session_id)

def decode(self,
input_ids,
input_embeddings: List[InputEmbeddingType] = None,
input_embedding_ranges: List[InputEmbeddingRangeType] = None,
steps: List[int] = None,
sequence_start: bool = True,
sequence_end: bool = True,
adapter_names: List[str] = None):
"""Perform context decode on input tokens.

Args:
input_ids (List[List[int]] | List[np.ndarray]): the batch of input
token ids
steps (List[int]): the offset of the k/v cache
input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
embeddings features
input_embedding_ranges: (List[List[Tuple[int, int]]]):
the begin/end offsets of input_embeddings to input_ids
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
"""
return self.engine_instance.decode(
input_ids,
input_embeddings=input_embeddings,
input_embedding_ranges=input_embedding_ranges,
steps=steps,
sequence_start=sequence_start,
sequence_end=sequence_end,
adapter_names=adapter_names)
229 changes: 163 additions & 66 deletions lmdeploy/serve/utils.py
@@ -81,7 +81,12 @@ def get_logits(
for input_id in input_ids:
assert len(input_id) > 0

bs = len(input_ids)
# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)
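# e.g. (illustrative numbers): with bs=1, vocab_size=131072 and fp32 logits,
# this gives 2 * 1024**3 // (1 * 131072 * 4) = 4096 tokens per iteration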

n_max_iter = np.ceil(
max([len(input_id)
for input_id in input_ids]) / max_input_len).astype(int)
@@ -173,79 +178,171 @@ def _split_embeddings(input_ids, niter, iter_len, embeddings,
logits = torch.cat(logits, dim=1)
return logits

def get_ppl(self, input_ids: Union[List[int],
            List[List[int]]]) -> List[float]:
"""Get perplexity scores given a list of input token ids.

Args:
input_ids (Union[List[int], List[List[int]]]): the batch of
input token ids

Returns:
List[float]: A list of perplexity scores.
"""
assert len(input_ids) > 0
assert isinstance(input_ids, List)
if isinstance(input_ids[0], int):
input_ids = [input_ids]
for input_id in input_ids:
assert len(input_id) > 1

generator = self.engine.create_instance()

# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (vocab_size * 4)
sizes = [len(_) for _ in input_ids]
losses = []
target_counts = []
sorted_index_values = sorted(list(enumerate(sizes)),
key=lambda x: x[1],
reverse=True)
sizes = [value for index, value in sorted_index_values]
indices = [index for index, value in sorted_index_values]
logger.info(f'sorted sizes: {sizes}')
logger.info(f'sorted indices: {indices}')
for (start, end) in self._batch_iterator(sizes, max_input_len):
logger.info(f'start: {start}, end: {end}')
if start == end:
    _input_ids = [input_ids[indices[start]]]
    loss, target_count = self._get_long_text_ppl(
        generator=generator,
        input_ids=_input_ids,
        max_input_len=max_input_len)
    losses.append(loss)
    target_counts.append(target_count)
else:
    _input_ids = [input_ids[indices[i]] for i in range(start, end)]
    loss, target_count = self._get_ppl(
        generator=generator,
        input_ids=_input_ids,
        max_input_len=max_input_len,
    )
    losses.append(loss)
    target_counts.append(target_count)
loss = torch.concatenate(losses)
target_count = torch.concatenate(target_counts)
loss_avg = loss / target_count
loss_avg = loss_avg.numpy().tolist()
result = list(range(len(loss_avg)))
for index, sorted_index in enumerate(indices):
result[sorted_index] = loss_avg[index]
return result

def _batch_iterator(self, sizes, max_value):
"""Return an iterator of (start, end) intervals over a list of sizes
sorted in descending order, such that the "sum of values" of each
interval is as large as possible without exceeding max_value. Here the
"sum of values" means len(sizes[start:end]) * sizes[start], i.e. the
cost of the batch when padded to its longest sequence.
"""
i = 0
while i < len(sizes):
current_sum = 0
start_index = i

while i < len(
sizes) and current_sum + sizes[start_index] <= max_value:
current_sum += sizes[start_index]
i += 1

yield (start_index, i)
if i > start_index:
continue
else:
i += 1
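
# For example (illustrative): sizes=[10, 6, 5, 3] with max_value=16 yields
# (0, 1), (1, 3), (3, 4); a leading size larger than max_value, e.g.
# sizes=[20, 6], yields (0, 0) first, which get_ppl routes to
# _get_long_text_ppl.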

def _get_long_text_ppl(self, generator, input_ids, max_input_len):
assert isinstance(input_ids, List) and len(input_ids) == 1
seq_len = len(input_ids[0])
assert seq_len > max_input_len
logger.info(f'get long text ppl: seq_len {seq_len}')

losses = []
target_counts = []
for i in range(0, seq_len, max_input_len):
token_ids = [input_ids[0][i:i + max_input_len]]
step = [i]
# shift token_ids by 1 to the left
target_ids = [input_ids[0][i + 1:i + 1 + max_input_len]]
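# the target slice starts at i + 1 and may reach one token into the next
# chunk, so the prediction for the last token of this iteration is kept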

loss, target_count = self._get_ppl(
generator=generator,
input_ids=token_ids,
max_input_len=max_input_len,
target_ids=target_ids,
steps=step,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= seq_len))
losses.append(loss)
target_counts.append(target_count)
loss_sum = torch.concatenate(losses).sum().unsqueeze(0)
target_count = torch.concatenate(target_counts).sum().unsqueeze(0)
return loss_sum, target_count

def _get_ppl(self,
generator,
input_ids,
max_input_len,
target_ids=None,
steps=None,
sequence_start: bool = True,
sequence_end: bool = True):
assert isinstance(input_ids, List)
assert all(isinstance(_, List) for _ in input_ids)
if target_ids:
assert all(isinstance(_, List) for _ in target_ids)

lens = [len(_) for _ in input_ids]
total_len = sum(lens)
assert sum(lens) <= max_input_len

logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
f'total_len: {total_len}')
torch.cuda.empty_cache()
logits = generator.decode(input_ids=input_ids,
steps=steps,
sequence_start=sequence_start,
sequence_end=sequence_end)
bsz, seq_len, vocab_size = logits.shape
logits = logits.float()
padding_token_id = -100
if target_ids is None:
# shift token_ids by 1 to the left
target_ids = [x[1:] + [padding_token_id] for x in input_ids]
else:
target_ids = [
    target_ids[i] + [padding_token_id]
    if len(target_ids[i]) < len(input_ids[i]) else target_ids[i]
    for i in range(bsz)
]
target_ids = [
    torch.LongTensor(_target_ids)
    for _target_ids in target_ids
]
target_ids = pad_sequence(target_ids,
batch_first=True,
padding_value=padding_token_id)
target_ids = target_ids.to(logits.device)
target_mask = target_ids != padding_token_id

# compute cross entropy loss
flat_logits = logits.contiguous().view(-1, vocab_size)
flat_target_ids = target_ids.contiguous().view(-1)
flat_loss_matrix = torch.nn.functional.cross_entropy(
flat_logits,
flat_target_ids,
reduction='none',
ignore_index=padding_token_id)
flat_loss_matrix = flat_loss_matrix.view(bsz, seq_len)
loss = flat_loss_matrix.sum(dim=-1).cpu()
target_count = target_mask.sum(dim=-1).cpu()
return loss, target_count
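
For reference, a hedged usage sketch of the pipeline API this backs (the model id and tokenizer call are illustrative assumptions, not part of this diff):

from transformers import AutoTokenizer
from lmdeploy import pipeline

model_path = 'internlm/internlm2_5-7b-chat'  # illustrative model id
pipe = pipeline(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_ids = tokenizer.encode('A long piece of text ...')
scores = pipe.get_ppl(input_ids)  # one score per input sequence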
4 changes: 4 additions & 0 deletions lmdeploy/turbomind/deploy/config.py
@@ -135,5 +135,9 @@ def weight_type(self):
def group_size(self):
return self.model_config.group_size

@property
def vocab_size(self):
return self.model_config.vocab_size

def __str__(self):
return json.dumps(self.to_dict(), indent=2)