diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh
index 414045fe163e5..d1200ee84dfe4 100644
--- a/.buildkite/run-cpu-test.sh
+++ b/.buildkite/run-cpu-test.sh
@@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
 trap remove_docker_container EXIT
 remove_docker_container
 
-# Run the image and launch offline inference
-docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py
+# Run the image
+docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
+
+# offline inference
+docker exec cpu-test bash -c "python3 examples/offline_inference.py"
+
+# Run basic model test
+docker exec cpu-test bash -c "cd tests;
+  pip install pytest Pillow protobuf
+  bash ../.buildkite/download-images.sh
+  cd ../
+  pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2
index 265833e2ccf6e..7e986c988407c 100644
--- a/.buildkite/test-template.j2
+++ b/.buildkite/test-template.j2
@@ -40,6 +40,8 @@ steps:
 
   - label: "Intel Test"
     depends_on: ~
+    agents:
+      queue: intel
     command: bash .buildkite/run-cpu-test.sh
 
 {% for step in steps %}
diff --git a/Dockerfile.cpu b/Dockerfile.cpu
index aec79824213f3..ae23e27b413ba 100644
--- a/Dockerfile.cpu
+++ b/Dockerfile.cpu
@@ -1,6 +1,6 @@
 # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
 
-FROM ubuntu:22.04
+FROM ubuntu:22.04 AS cpu-test-1
 
 RUN apt-get update -y \
     && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
@@ -9,6 +9,8 @@ RUN apt-get update -y \
 RUN pip install --upgrade pip \
     && pip install wheel packaging ninja setuptools>=49.4.0 numpy
 
+FROM cpu-test-1 AS build
+
 COPY ./ /workspace/vllm
 
 WORKDIR /workspace/vllm
@@ -19,4 +21,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
 
 WORKDIR /workspace/
 
+RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
+
 CMD ["/bin/bash"]
diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp
index 73bf77e46f538..e8aead17ae5a7 100644
--- a/csrc/cpu/pos_encoding.cpp
+++ b/csrc/cpu/pos_encoding.cpp
@@ -21,73 +21,74 @@ void rotary_embedding_impl(
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
   const int embed_dim = rot_dim / 2;
-  TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
+  bool flag = (embed_dim % VEC_ELEM_NUM == 0);
+  const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
 
-#pragma omp parallel for
-  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
-    int64_t pos = positions[token_idx];
-    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+  auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
+                          scalar_t* qk) {
+    int j = 0;
+    for (; j < loop_upper; j += VEC_ELEM_NUM) {
+      const int rot_offset = j;
+      const int x_index = rot_offset;
+      const int y_index = embed_dim + rot_offset;
 
-    for (int i = 0; i < num_heads; ++i) {
-      const int head_idx = i;
-      const int64_t token_head =
-          token_idx * query_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
+      const int64_t out_x = token_head + x_index;
+      const int64_t out_y = token_head + y_index;
 
-        const int64_t out_x = token_head + x_index;
-        const int64_t out_y = token_head + y_index;
+      const scalar_vec_t cos(cache_ptr + x_index);
+      const scalar_vec_t sin(cache_ptr + y_index);
 
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
+      const scalar_vec_t q_x(qk + out_x);
+      const scalar_vec_t q_y(qk + out_y);
 
-        const scalar_vec_t q_x(query + out_x);
-        const scalar_vec_t q_y(query + out_y);
+      vec_op::FP32Vec8 fp32_cos(cos);
+      vec_op::FP32Vec8 fp32_sin(sin);
 
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
+      vec_op::FP32Vec8 fp32_q_x(q_x);
+      vec_op::FP32Vec8 fp32_q_y(q_y);
 
-        vec_op::FP32Vec8 fp32_q_x(q_x);
-        vec_op::FP32Vec8 fp32_q_y(q_y);
+      auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+      scalar_vec_t(out1).save(qk + out_x);
 
-        auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
-        scalar_vec_t(out1).save(query + out_x);
-
-        auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
-        scalar_vec_t(out2).save(query + out_y);
-      }
+      auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      scalar_vec_t(out2).save(qk + out_y);
     }
-
-    for (int i = 0; i < num_kv_heads; ++i) {
-      const int head_idx = i;
-      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
+    if (!flag) {
+      for (; j < embed_dim; ++j) {
+        const int x_index = j;
+        const int y_index = embed_dim + j;
 
         const int64_t out_x = token_head + x_index;
         const int64_t out_y = token_head + y_index;
 
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
+        const float fp32_cos = cache_ptr[x_index];
+        const float fp32_sin = cache_ptr[y_index];
 
-        const scalar_vec_t k_x(key + out_x);
-        const scalar_vec_t k_y(key + out_y);
+        const float fp32_q_x = qk[out_x];
+        const float fp32_q_y = qk[out_y];
 
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
+        qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+        qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      }
+    }
+  };
 
-        vec_op::FP32Vec8 fp32_k_x(k_x);
-        vec_op::FP32Vec8 fp32_k_y(k_y);
+#pragma omp parallel for
+  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    int64_t pos = positions[token_idx];
+    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
-        auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
-        scalar_vec_t(out1).save(key + out_x);
-        auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
-        scalar_vec_t(out2).save(key + out_y);
-      }
+    for (int i = 0; i < num_heads; ++i) {
+      const int head_idx = i;
+      const int64_t token_head =
+          token_idx * query_stride + head_idx * head_size;
+      compute_loop(token_head, cache_ptr, query);
+    }
+
+    for (int i = 0; i < num_kv_heads; ++i) {
+      const int head_idx = i;
+      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+      compute_loop(token_head, cache_ptr, key);
     }
   }
 }
diff --git a/tests/conftest.py b/tests/conftest.py
index 796f498bb28a8..8fcd91305e3a1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -20,6 +20,7 @@
 from vllm.multimodal import MultiModalData
 from vllm.multimodal.image import ImageFeatureData, ImagePixelData
 from vllm.sequence import SampleLogprobs
+from vllm.utils import is_cpu
 
 logger = init_logger(__name__)
 
@@ -60,7 +61,8 @@ def cleanup():
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
-    torch.cuda.empty_cache()
+    if not is_cpu():
+        torch.cuda.empty_cache()
 
 
 @pytest.fixture()
@@ -153,6 +155,12 @@ def example_long_prompts() -> List[str]:
 
 class HfRunner:
 
+    def wrap_device(self, input: any):
+        if not is_cpu():
+            return input.to("cuda")
+        else:
+            return input.to("cpu")
+
     def __init__(
         self,
         model_name: str,
@@ -167,17 +175,18 @@ def __init__(
         if model_name in _EMBEDDING_MODELS:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
-            self.model = SentenceTransformer(
-                model_name,
-                device="cpu",
-            ).to(dtype=torch_dtype).cuda()
+            self.model = self.wrap_device(
+                SentenceTransformer(
+                    model_name,
+                    device="cpu",
+                ).to(dtype=torch_dtype))
         else:
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                trust_remote_code=True,
-                token=access_token,
-            ).cuda()
+            self.model = self.wrap_device(
+                AutoModelForCausalLM.from_pretrained(
+                    model_name,
+                    torch_dtype=torch_dtype,
+                    trust_remote_code=True,
+                ))
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
@@ -218,7 +227,7 @@ def generate(
             inputs = self.processor(**processor_kwargs)
 
             output_ids = self.model.generate(
-                **inputs.to("cuda"),
+                **self.wrap_device(inputs),
                 use_cache=True,
                 **kwargs,
             )
@@ -275,7 +284,7 @@ def generate_greedy_logprobs(
         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
             output = self.model.generate(
-                input_ids.cuda(),
+                self.wrap_device(input_ids),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
@@ -310,7 +319,7 @@ def generate_greedy_logprobs_limit(
         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
             output = self.model.generate(
-                input_ids.cuda(),
+                self.wrap_device(input_ids),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py
index a7abc011f57d7..85d74f7f5b03d 100644
--- a/tests/models/test_aqlm.py
+++ b/tests/models/test_aqlm.py
@@ -8,10 +8,13 @@
 
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-aqlm_not_supported = (capability <
-                      QUANTIZATION_METHODS["aqlm"].get_min_capability())
+aqlm_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    aqlm_not_supported = (capability <
+                          QUANTIZATION_METHODS["aqlm"].get_min_capability())
 
 # In this test we hardcode prompts and generations for the model so we don't
 # need to require the AQLM package as a dependency
diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py
index 8116b796287a5..fd1253f73c93f 100644
--- a/tests/models/test_big_models.py
+++ b/tests/models/test_big_models.py
@@ -8,6 +8,7 @@
 import sys
 
 import pytest
+import torch
 
 MODELS = [
     "meta-llama/Llama-2-7b-hf",
@@ -36,9 +37,14 @@
     "mosaicml/mpt-7b",
 ]
 
+#TODO: remove this after CPU float16 support ready
+target_dtype = "float"
+if torch.cuda.is_available():
+    target_dtype = "half"
+
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [32])
 def test_models(
     hf_runner,
@@ -78,7 +84,7 @@ def test_models(
 
 @pytest.mark.skip("Slow and not useful (just prints model).")
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", [target_dtype])
 def test_model_print(
     vllm_runner,
     model: str,
diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py
index 0a5819ea3f054..61aee0d0a6e93 100644
--- a/tests/models/test_fp8.py
+++ b/tests/models/test_fp8.py
@@ -67,10 +67,13 @@
     },
 }
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-fp8_not_supported = (capability <
-                     QUANTIZATION_METHODS["fp8"].get_min_capability())
+fp8_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    fp8_not_supported = (capability <
+                         QUANTIZATION_METHODS["fp8"].get_min_capability())
 
 
 @pytest.mark.skipif(fp8_not_supported,
diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py
index 561d4a1756587..da549cae0054f 100644
--- a/tests/models/test_gptq_marlin.py
+++ b/tests/models/test_gptq_marlin.py
@@ -21,10 +21,13 @@
 
 MAX_MODEL_LEN = 1024
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-gptq_marlin_not_supported = (
-    capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
+gptq_marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    gptq_marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
 
 MODELS = [
     # act_order==False, group_size=channelwise
diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/test_gptq_marlin_24.py
index 3e6ffb7f90fcc..cc35ee803ff01 100644
--- a/tests/models/test_gptq_marlin_24.py
+++ b/tests/models/test_gptq_marlin_24.py
@@ -14,10 +14,13 @@
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (capability <
-                        QUANTIZATION_METHODS["marlin"].get_min_capability())
+marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
 
 
 @dataclass
diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py
index d3770fa69f6f1..585c5ad686d16 100644
--- a/tests/models/test_marlin.py
+++ b/tests/models/test_marlin.py
@@ -23,10 +23,15 @@
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (capability <
-                        QUANTIZATION_METHODS["marlin"].get_min_capability())
+from .utils import check_logprobs_close
+
+marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
 
 
 @dataclass