
Commit

[ci] add trtllm chat test
sindhuvahinis committed Jun 25, 2024
1 parent e6ff687 commit c0782ef
Showing 3 changed files with 23 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/integration/llm/client.py
@@ -553,6 +553,15 @@ def get_model_name():
     }
 }
 
+trtllm_chat_model_spec = {
+    "llama2-7b-chat": {
+        "max_memory_per_gpu": [25.0],
+        "batch_size": [1, 4],
+        "seq_length": [256],
+        "tokenizer": "TheBloke/Llama-2-7B-Chat-fp16"
+    }
+}
+
 no_code_rolling_batch_spec = {
     "llama-7b": {
         "max_memory_per_gpu": [25.0],
@@ -1286,6 +1295,8 @@ def run(raw_args):
         test_handler_rolling_batch(args.model, lmi_dist_aiccl_model_spec)
     elif args.handler == "trtllm":
         test_handler_rolling_batch(args.model, trtllm_model_spec)
+    elif args.handler == "trtllm_chat":
+        test_handler_rolling_batch_chat(args.model, trtllm_chat_model_spec)
    elif args.handler == "no_code":
         test_handler_rolling_batch(args.model, no_code_rolling_batch_spec)
 
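Note: test_handler_rolling_batch_chat is defined elsewhere in client.py and is not part of this diff. As a rough sketch of the kind of request a chat handler receives (endpoint path, prompt text, and payload shape are illustrative assumptions based on the OpenAI-style chat schema, not the exact test code):

import requests

def send_chat_request(max_tokens, url="http://127.0.0.1:8080/invocations"):
    # Chat handlers accept role-tagged messages rather than a raw "inputs"
    # string; the server applies the model's chat template before generation.
    payload = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is Deep Learning?"},
        ],
        "max_tokens": max_tokens,
    }
    res = requests.post(url, json=payload)
    res.raise_for_status()
    return res.json()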
6 changes: 6 additions & 0 deletions tests/integration/llm/prepare.py
@@ -862,6 +862,12 @@
         "option.use_custom_all_reduce": False,
         "option.max_rolling_batch_size": 32,
         "option.output_formatter": "jsonlines"
+    },
+    "llama2-7b-chat": {
+        "option.model_id": "s3://djl-llm/meta-llama-Llama-2-7b-chat-hf/",
+        "option.dtype": "fp16",
+        "option.tensor_parallel_degree": 4,
+        "option.max_rolling_batch_size": 4
     }
 }
 
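Each key/value pair in the new llama2-7b-chat entry is written out as one line of the model's serving.properties before the container starts. A minimal sketch of that conversion (hypothetical helper for illustration; the real logic, including any injected defaults, lives in prepare.py):

def write_properties(spec, path="serving.properties"):
    # e.g. "option.dtype": "fp16" becomes the line option.dtype=fp16
    with open(path, "w") as f:
        for key, value in spec.items():
            f.write(f"{key}={value}\n")

write_properties({
    "option.model_id": "s3://djl-llm/meta-llama-Llama-2-7b-chat-hf/",
    "option.dtype": "fp16",
    "option.tensor_parallel_degree": 4,
    "option.max_rolling_batch_size": 4,
})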
6 changes: 6 additions & 0 deletions tests/integration/tests.py
@@ -173,6 +173,12 @@ def test_qwen_7b(self):
             r.launch("CUDA_VISIBLE_DEVICES=0,1,2,3")
             client.run("trtllm qwen-7b".split())
 
+    def test_llama2_7b_chat(self):
+        with Runner('tensorrt-llm', 'llama2-7b-chat') as r:
+            prepare.build_trtllm_handler_model("llama2-7b-chat")
+            r.launch("CUDA_VISIBLE_DEVICES=0,1,2,3")
+            client.run("trtllm_chat llama2-7b-chat".split())
+
 
 class TestSchedulerSingleGPU:
     # Runs on g5.12xl
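The new case mirrors test_qwen_7b: Runner starts the tensorrt-llm container, prepare.build_trtllm_handler_model stages the model, and client.run exercises the trtllm_chat handler against it on four GPUs (per CUDA_VISIBLE_DEVICES=0,1,2,3). Assuming the usual pytest layout, the case can be run in isolation with pytest tests.py -k test_llama2_7b_chat.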
