From 0faab90eb006c677add65cd4c2d0f740a63e064d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 20 Sep 2024 19:55:33 -0700 Subject: [PATCH] [beam search] add output for manually checking the correctness (#8684) --- tests/samplers/test_beam_search.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 64f3ce94b7a83..98a02dec895d2 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -11,7 +11,7 @@ # 3. Use the model "huggyllama/llama-7b". MAX_TOKENS = [128] BEAM_WIDTHS = [4] -MODELS = ["facebook/opt-125m"] +MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"] @pytest.mark.parametrize("model", MODELS) @@ -37,8 +37,15 @@ def test_beam_search_single_input( beam_width, max_tokens) for i in range(len(example_prompts)): - hf_output_ids, _ = hf_outputs[i] - vllm_output_ids, _ = vllm_outputs[i] + hf_output_ids, hf_output_texts = hf_outputs[i] + vllm_output_ids, vllm_output_texts = vllm_outputs[i] + for i, (hf_text, + vllm_text) in enumerate(zip(hf_output_texts, + vllm_output_texts)): + print(f">>>{i}-th hf output:") + print(hf_text) + print(f">>>{i}-th vllm output:") + print(vllm_text) assert len(hf_output_ids) == len(vllm_output_ids) for j in range(len(hf_output_ids)): assert hf_output_ids[j] == vllm_output_ids[j], (