diff --git a/test_inference.py b/test_inference.py
index f7aa1259..c618d99e 100644
--- a/test_inference.py
+++ b/test_inference.py
@@ -467,22 +467,36 @@ def test_ppl_token():
         print(f" -- Measuring prompt speed...")
 
+        torch.cuda.synchronize()
+
         current_len = 128
+        step = 128
+        prompt_iters = 3
         while True:
 
-            time_begin = time.time()
+            total_time = 0
+            for i in range(prompt_iters):
 
-            cache.current_seq_len = 0
-            model.forward(ids[:, :current_len], cache, preprocess_only = True)
-            torch.cuda.synchronize()
+                torch.cuda.synchronize()
+                time_begin = time.time()
 
-            time_end = time.time()
-            tps = current_len / (time_end - time_begin)
+                cache.current_seq_len = 0
+                model.forward(ids[:, :current_len], cache, preprocess_only = True)
+
+                torch.cuda.synchronize()
+                time_end = time.time()
+                total_time += time_end - time_begin
+
+            tps = current_len / (total_time / prompt_iters)
 
             print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")
 
+            if current_len >= 1024: step = 1024
+            if current_len >= 4096: step = 4096
+            if current_len >= 16384: step = 8192
+
             current_len_ = current_len
-            current_len = min(current_len + 128, model.config.max_seq_len)
+            current_len = min(current_len + step, model.config.max_seq_len)
             if current_len == current_len_: break