
get_ppl missed the last token of each iteration during multi-iter prefill #2499

Merged: 14 commits, Sep 26, 2024
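For context on the bug: with next-token prediction, a window's logits are scored against that window's tokens shifted left by one, so the prediction for the first token of window k+1 is produced by the last logit of window k. If the prefill windows don't overlap, that prediction is silently dropped. Below is a toy, self-contained demonstration (the token ids and window size are made up, not taken from the code):

# Toy demonstration: which target tokens receive a loss term when a
# sequence is prefilled window by window. A window's labels are its
# tokens shifted left by one (window[1:]), so disjoint windows skip
# the first token of every window after the first.

tokens = list(range(9))  # toy sequence [0..8]
max_input_len = 5

def covered_targets(starts):
    covered = set()
    for s in starts:
        window = tokens[s:s + max_input_len]
        covered.update(window[1:])  # labels = window shifted left by one
    return sorted(covered)

# Buggy scheme: disjoint windows [0..4] and [5..8]
print(covered_targets([0, 5]))  # [1, 2, 3, 4, 6, 7, 8] -- token 5 is missed
# Fixed scheme: stride max_input_len - 1, windows overlap by one token
print(covered_targets([0, 4]))  # [1, 2, 3, 4, 5, 6, 7, 8] -- complete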
49 changes: 23 additions & 26 deletions lmdeploy/serve/utils.py
@@ -180,42 +180,39 @@ def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
             input_ids (Union[List[int], List[List[int]]]): the batch of
                 input token ids
         """
-        assert len(input_ids) > 0
+        assert isinstance(input_ids, List) and len(input_ids) > 0
         if isinstance(input_ids[0], int):
             input_ids = [input_ids]
-        for input_id in input_ids:
-            assert len(input_id) > 1
+        assert all(len(_) > 1 for _ in input_ids)
 
-        max_input_len = self.backend_config.max_prefill_token_num
-        n_max_iter = np.ceil(
-            max([len(input_id)
-                 for input_id in input_ids]) / max_input_len).astype(int)
+        bs = len(input_ids)
+        max_seq_len = max([len(input_id) for input_id in input_ids])
 
-        index_range_starts = []
-        index_range_ends = []
-        for input_id in input_ids:
-            index_range_start = np.array(
-                [i * max_input_len for i in range(n_max_iter)])
-            index_range_end = index_range_start + max_input_len
-            index_range_start[index_range_start >= len(input_id)] = len(
-                input_id)
-            index_range_end[index_range_end >= len(input_id)] = len(input_id)
-            index_range_starts.append(index_range_start)
-            index_range_ends.append(index_range_end)
+        # TODO: a better way to determine `max_input_len`
+        # At most allocate 2G mem for logits with shape [bs, seq, vocab_size]
+        vocab_size = self.hf_tm_cfg.vocab_size
+        max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)
 
         generator = self.engine.create_instance()
         all_loss_matrix = []
         all_target_mask = []
-        for i in range(n_max_iter):
-            steps = [start[i] for start in index_range_starts]
+        # Suppose input_ids is [0,1,2,3,4,5,6,7,8] and max_input_len=5.
+        # In the first iter, tokens [0,1,2,3,4] are prefilled:
+        #   loss = cross_entropy(logits[..., :-1, :], token_ids [1,2,3,4])
+        # In the 2nd iter, tokens [4,5,6,7,8] should be prefilled. The first
+        # token must be the last one of the prev iter, because the token_ids
+        # (i.e. the labels) have to be shifted one token to the left:
+        #   loss = cross_entropy(logits[..., :-1, :], token_ids [5,6,7,8])
+        for i in range(0, max_seq_len, max_input_len - 1):
             _input_ids = [
-                input_id[start[i]:end[i]] for input_id, start, end in zip(
-                    input_ids, index_range_starts, index_range_ends)
+                input_id[i:i + max_input_len] for input_id in input_ids
             ]
-            _logits = generator.decode(_input_ids,
-                                       steps=steps,
-                                       sequence_start=(i == 0),
-                                       sequence_end=(i == n_max_iter - 1))
+            steps = [i] * bs
+            _logits = generator.decode(
+                _input_ids,
+                steps=steps,
+                sequence_start=(i == 0),
+                sequence_end=(i + max_input_len >= max_seq_len))
             _logits = _logits.float().cpu()
             padding_token_id = -100
             target_ids = [(x + [padding_token_id])[1:] for x in _input_ids]
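The diff is collapsed at this point on the original page. As a rough sketch of how the -100-padded target_ids typically pair with the per-window logits — this is an illustration under assumptions, not the hidden remainder of the file; -100 matches torch.nn.functional.cross_entropy's default ignore_index:

import torch
import torch.nn.functional as F

def window_loss_and_mask(logits: torch.Tensor, target_ids: torch.Tensor):
    # logits: [seq_len, vocab_size]; target_ids: [seq_len], -100 = padding.
    # With reduction='none', positions whose target equals ignore_index
    # contribute zero loss.
    loss = F.cross_entropy(logits, target_ids,
                           ignore_index=-100, reduction='none')
    mask = (target_ids != -100).float()  # 1 for real targets, 0 for padding
    return loss, mask

# Accumulating (loss, mask) pairs over all windows, a per-sequence
# perplexity would then be exp(loss.sum() / mask.sum()).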
4 changes: 4 additions & 0 deletions lmdeploy/turbomind/deploy/config.py
@@ -135,5 +135,9 @@ def weight_type(self):
     def group_size(self):
         return self.model_config.group_size
 
+    @property
+    def vocab_size(self):
+        return self.model_config.vocab_size
+
     def __str__(self):
         return json.dumps(self.to_dict(), indent=2)
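This new vocab_size property is presumably what the get_ppl change above reads via self.hf_tm_cfg.vocab_size to size its logits budget. A worked example of the 2 GiB cap, with illustrative numbers:

# Worked example of the logits memory budget used in get_ppl above.
# Logits have shape [bs, seq, vocab_size] in float32 (4 bytes/element),
# and at most 2 GiB is allocated for them.
bs = 2
vocab_size = 32000  # illustrative, e.g. a Llama-2-sized vocabulary
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)
print(max_input_len)  # 8388 tokens per prefill window for this config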