Update prompt speed test
turboderp committed Apr 18, 2024
1 parent aef7bd1 commit dc1dfc4
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions test_inference.py
@@ -467,22 +467,36 @@ def test_ppl_token():
 
     print(f" -- Measuring prompt speed...")
 
     torch.cuda.synchronize()
 
     current_len = 128
+    step = 128
+    prompt_iters = 3
     while True:
 
-        time_begin = time.time()
+        total_time = 0
+        for i in range(prompt_iters):
 
-        cache.current_seq_len = 0
-        model.forward(ids[:, :current_len], cache, preprocess_only = True)
-        torch.cuda.synchronize()
+            torch.cuda.synchronize()
+            time_begin = time.time()
 
-        time_end = time.time()
-        tps = current_len / (time_end - time_begin)
+            cache.current_seq_len = 0
+            model.forward(ids[:, :current_len], cache, preprocess_only = True)
+
+            torch.cuda.synchronize()
+            time_end = time.time()
+            total_time += time_end - time_begin
+
+        tps = current_len / (total_time / prompt_iters)
 
         print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")
 
+        if current_len >= 1024: step = 1024
+        if current_len >= 4096: step = 4096
+        if current_len >= 16384: step = 8192
+
         current_len_ = current_len
-        current_len = min(current_len + 128, model.config.max_seq_len)
+        current_len = min(current_len + step, model.config.max_seq_len)
         if current_len == current_len_: break
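The pattern the updated loop adopts is the standard recipe for timing CUDA work from the host: synchronize immediately before starting the clock so no earlier kernels are still in flight, synchronize again before stopping it so the forward pass has actually finished, and average over prompt_iters runs to smooth out launch-to-launch jitter. A minimal standalone sketch of that recipe, assuming any CUDA-capable PyTorch install (time_gpu_op and run_op are illustrative names, not part of this commit):

    import time
    import torch

    def time_gpu_op(run_op, iters = 3):
        # Illustrative helper, not from the commit: times an arbitrary GPU
        # operation the same way the updated test times model.forward().
        total_time = 0.0
        for _ in range(iters):
            torch.cuda.synchronize()    # drain pending kernels before timing
            time_begin = time.time()
            run_op()                    # e.g. one forward pass over the prompt
            torch.cuda.synchronize()    # wait for the operation to finish
            total_time += time.time() - time_begin
        return total_time / iters       # mean wall-clock seconds per run

Without the second synchronize, time.time() would be read as soon as the kernels were queued rather than when they completed; the old code already had that barrier after the forward pass, and this commit adds the per-iteration barrier before time_begin plus the averaging over three runs.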


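The new step schedule also grows geometrically so that sweeping long contexts stays tractable: lengths advance by 128 tokens up to 1024, by 1024 up to 4096, by 4096 up to 16384, and by 8192 beyond that, capped at model.config.max_seq_len. A quick way to see which lengths the loop will visit (visited_lengths is a hypothetical helper mirroring the diff, not part of the commit):

    def visited_lengths(max_seq_len):
        # Hypothetical helper reproducing the updated loop's length schedule.
        lengths, current_len, step = [], 128, 128
        while True:
            lengths.append(current_len)
            if current_len >= 1024: step = 1024
            if current_len >= 4096: step = 4096
            if current_len >= 16384: step = 8192
            previous = current_len
            current_len = min(current_len + step, max_seq_len)
            if current_len == previous: break
        return lengths

    print(visited_lengths(32768))
    # [128, 256, ..., 896, 1024, 2048, 3072, 4096, 8192, 12288, 16384, 24576, 32768]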