From 5b55f37b871d9074e4d7cae9c58f746837aa0e80 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 29 Jun 2024 19:38:38 -0700 Subject: [PATCH 1/6] remove some spawn --- .buildkite/test-pipeline.yaml | 9 ------- .../test_basic_distributed_correctness.py | 10 +++++--- .../test_chunked_prefill_distributed.py | 9 +++++-- tests/models/test_llava.py | 15 +++++++---- tests/models/test_phi3v.py | 25 +++++++++++-------- 5 files changed, 39 insertions(+), 29 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 307ada611a859..f26e03fd8b145 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -45,9 +45,6 @@ steps: num_gpus: 2 commands: - bash ../.buildkite/download-images.sh - # FIXIT: find out which code initialize cuda before running the test - # before the fix, we need to use spawn to test it - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py @@ -71,9 +68,6 @@ steps: working_dir: "/vllm-workspace/tests" num_gpus: 4 commands: - # FIXIT: find out which code initialize cuda before running the test - # before the fix, we need to use spawn to test it - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s distributed/test_pynccl.py # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. @@ -225,9 +219,6 @@ steps: gpu: a100 num_gpus: 4 commands: - # FIXIT: find out which code initialize cuda before running the test - # before the fix, we need to use spawn to test it - - export VLLM_WORKER_MULTIPROC_METHOD=spawn # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index b8ae5b4c44f8d..425d66230bc10 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -38,9 +38,10 @@ def test_models( ) -> None: distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
with vllm_runner(model, dtype=dtype, tensor_parallel_size=2, @@ -48,6 +49,9 @@ def test_models( ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] vllm_output_ids, vllm_output_str = vllm_outputs[i] diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 4e4e468c4377a..f19253eba7cae 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -45,8 +45,10 @@ def test_models( enable_chunked_prefill = True max_num_batched_tokens = chunked_prefill_token_size - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). with vllm_runner( model, @@ -59,6 +61,9 @@ def test_models( ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] vllm_output_ids, vllm_output_str = vllm_outputs[i] diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index f2dfd4bb8596f..ec5c33819540a 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -89,16 +89,16 @@ def run_test( hf_images = [asset.for_hf() for asset in image_assets] vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets] - with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model: - hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, - max_tokens, - images=hf_images) - vllm_image_prompts = [ p.replace("<image>", "<image>" * vlm_config.image_feature_size) for p in HF_IMAGE_PROMPTS ] + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ with vllm_runner(model_id, dtype=dtype, tensor_parallel_size=tensor_parallel_size, @@ -109,6 +109,11 @@ def run_test( max_tokens, images=vllm_images) + with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model: + hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, + max_tokens, + images=hf_images) + for i in range(len(HF_IMAGE_PROMPTS)): hf_output_ids, hf_output_str = hf_outputs[i] vllm_output_ids, vllm_output_str = vllm_to_hf_output( diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index e7d5639494104..c79b59d0d7d79 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -97,22 +97,17 @@ def run_test( hf_images = [asset.for_hf() for asset in image_assets] vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets] - # use eager mode for hf runner, since phi3_v didn't work with flash_attn - hf_model_kwargs = {"_attn_implementation": "eager"} - with hf_runner(model_id, dtype=dtype, - model_kwargs=hf_model_kwargs) as hf_model: - hf_outputs = hf_model.generate_greedy( - HF_IMAGE_PROMPTS, - max_tokens, - images=hf_images, - eos_token_id=hf_model.processor.tokenizer.eos_token_id) - vllm_image_prompts = [ p.replace("<|image_1|>", "<|image|>" * vlm_config.image_feature_size + "") for p in HF_IMAGE_PROMPTS ] + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner(model_id, max_model_len=2048, dtype=dtype, @@ -124,6 +119,16 @@ def run_test( max_tokens, images=vllm_images) + # use eager mode for hf runner, since phi3_v didn't work with flash_attn + hf_model_kwargs = {"_attn_implementation": "eager"} + with hf_runner(model_id, dtype=dtype, + model_kwargs=hf_model_kwargs) as hf_model: + hf_outputs = hf_model.generate_greedy( + HF_IMAGE_PROMPTS, + max_tokens, + images=hf_images, + eos_token_id=hf_model.processor.tokenizer.eos_token_id) + for i in range(len(HF_IMAGE_PROMPTS)): hf_output_ids, hf_output_str = hf_outputs[i] vllm_output_ids, vllm_output_str = vllm_to_hf_output( From 4949d917d5b8e0a99114638bb3493983e4c9046e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 29 Jun 2024 19:52:54 -0700 Subject: [PATCH 2/6] use cuda_device_count_stateless --- tests/distributed/test_basic_distributed_correctness.py | 5 +++-- tests/distributed/test_chunked_prefill_distributed.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 425d66230bc10..fe1d3e0929ab1 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -15,7 +15,8 @@ import os import pytest -import torch + +from vllm.utils import cuda_device_count_stateless MODELS = [ os.environ["TEST_DIST_MODEL"], @@ -23,7 +24,7 @@ DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" -@pytest.mark.skipif(torch.cuda.device_count() < 2, +@pytest.mark.skipif(cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index f19253eba7cae..b0ff0deb47245 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py 
+++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -14,7 +14,8 @@ import os import pytest -import torch + +from vllm.utils import cuda_device_count_stateless MODELS = [ os.environ["TEST_DIST_MODEL"], @@ -22,7 +23,7 @@ DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" -@pytest.mark.skipif(torch.cuda.device_count() < 2, +@pytest.mark.skipif(cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) From aaa3c9f9ea08b4c830b317fe1cc82d5ecfa241f8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 29 Jun 2024 22:09:07 -0700 Subject: [PATCH 3/6] fix AutoProcessor --- tests/conftest.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index b429d8d0b5600..ff4510e5364f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from PIL import Image from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, - AutoProcessor, AutoTokenizer, BatchEncoding) + AutoTokenizer, BatchEncoding) from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig @@ -216,6 +216,9 @@ def __init__( ) try: + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 self.processor = AutoProcessor.from_pretrained( model_name, torch_dtype=torch_dtype, From 99f03968670645b48160898525b2e6e51f87b9b7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 29 Jun 2024 22:11:34 -0700 Subject: [PATCH 4/6] fix MultiModalData --- tests/conftest.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ff4510e5364f0..d385467a5632b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,8 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict, - TypeVar) +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, + TypedDict, TypeVar) import pytest import torch @@ -22,8 +22,12 @@ destroy_model_parallel) from vllm.inputs import TextPrompt from vllm.logger import init_logger -from vllm.multimodal import MultiModalData -from vllm.multimodal.image import ImageFeatureData, ImagePixelData + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalData +else: + # it will call torch.cuda.device_count() + MultiModalData = None from vllm.sequence import SampleLogprobs from vllm.utils import cuda_device_count_stateless, is_cpu @@ -63,6 +67,10 @@ def for_hf(self) -> Image.Image: return self.pil_image def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData: + # don't put this import at the top level + # it will call torch.cuda.device_count() + from vllm.multimodal.image import (ImageFeatureData, # noqa: F401 + ImagePixelData) image_input_type = vision_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType From f566ae7bf9483a886a4c8ffbf16a4ea66af3d00b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 29 Jun 2024 23:02:32 -0700 Subject: [PATCH 5/6] fix llava --- tests/conftest.py | 4 ++-- tests/models/test_llava.py | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d385467a5632b..0bd24905efab8 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -69,8 +69,8 @@ def for_hf(self) -> Image.Image: def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData: # don't put this import at the top level # it will call torch.cuda.device_count() - from vllm.multimodal.image import (ImageFeatureData, # noqa: F401 - ImagePixelData) + from vllm.multimodal.image import ImageFeatureData # noqa: F401 + from vllm.multimodal.image import ImagePixelData image_input_type = vision_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index ec5c33819540a..6d375e66ee19c 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -87,12 +87,6 @@ def run_test( """ model_id, vlm_config = model_and_config hf_images = [asset.for_hf() for asset in image_assets] - vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets] - - vllm_image_prompts = [ - p.replace("<image>", "<image>" * vlm_config.image_feature_size) - for p in HF_IMAGE_PROMPTS - ] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. @@ -105,6 +99,17 @@ def run_test( distributed_executor_backend=distributed_executor_backend, enforce_eager=True, **vlm_config.as_cli_args_dict()) as vllm_model: + + # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()` + # we must put it inside the vllm_runner context manager + # i.e. after creating vLLM instance. + vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets] + + vllm_image_prompts = [ + p.replace("<image>", "<image>" * vlm_config.image_feature_size) + for p in HF_IMAGE_PROMPTS + ] + vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, max_tokens, images=vllm_images) From 359fdcaae34c78285e7f126ec24b5acf85c28f20 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 29 Jun 2024 23:11:40 -0700 Subject: [PATCH 6/6] fix phi3v test --- .buildkite/test-pipeline.yaml | 3 +-- tests/models/test_phi3v.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f26e03fd8b145..dcc4a8c8e6648 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -57,8 +57,7 @@ steps: - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - # FIXIT: find out why TP is failing with mp backend on phi3-v - # - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py + - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index c79b59d0d7d79..91fe7a218e67d 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -95,13 +95,6 @@ def run_test( """ model_id, vlm_config = model_and_config hf_images = [asset.for_hf() for asset in image_assets] - vllm_images =
[asset.for_vllm(vlm_config) for asset in image_assets] - - vllm_image_prompts = [ - p.replace("<|image_1|>", - "<|image|>" * vlm_config.image_feature_size + "") - for p in HF_IMAGE_PROMPTS - ] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. @@ -115,6 +108,18 @@ def run_test( enforce_eager=True, distributed_executor_backend=distributed_executor_backend, **vlm_config.as_cli_args_dict()) as vllm_model: + # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()` + # we must put it inside the vllm_runner context manager + # i.e. after creating vLLM instance. + + vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets] + + vllm_image_prompts = [ + p.replace("<|image_1|>", + "<|image|>" * vlm_config.image_feature_size + "") + for p in HF_IMAGE_PROMPTS + ] + vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, max_tokens, images=vllm_images)
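Context for the series as a whole (not part of any commit above): the recurring NOTE comments, the lazy imports in conftest.py, and the switch to cuda_device_count_stateless all serve one goal, namely that the parent test process must not initialize CUDA (for example via torch.cuda.device_count() or by running the HF model first) before the fork-based multiprocessing backend creates its workers. A minimal sketch of the underlying idea, counting GPUs through NVML instead of the CUDA runtime, could look like the following; the function name, the pynvml dependency, and the simplified CUDA_VISIBLE_DEVICES handling are illustrative assumptions, not vLLM's actual implementation:

    import os

    import pynvml  # provided by the nvidia-ml-py package


    def gpu_count_without_cuda_init() -> int:
        """Count visible GPUs via NVML so this process never initializes CUDA."""
        pynvml.nvmlInit()
        try:
            physical = pynvml.nvmlDeviceGetCount()
        finally:
            pynvml.nvmlShutdown()
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible is None:
            return physical
        # Simplified: treat every non-empty entry as a valid device index or UUID.
        return min(len([d for d in visible.split(",") if d]), physical)

Because such a helper never touches torch.cuda, the @pytest.mark.skipif decorators can evaluate the GPU count at collection time without hurting the workers that are forked later; vLLM's real helper is the cuda_device_count_stateless imported from vllm.utils in the diffs above.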